@@ -126,36 +126,38 @@ end
126126
127127@generated function zero_tangent (primal)
128128 fieldcount (primal) == 0 && return NoTangent () # no tangent space at all, no need for structural zero.
129- zfield_exprs = map (fieldnames (primal)) do fname
130- fval = :(
131- if isdefined (primal, $ (QuoteNode (fname)))
132- zero_tangent (getfield (primal, $ (QuoteNode (fname))))
133- else
134- # This is going to be potentially bad, but that's what they get for not giving us a primal
135- # This will never me mutated inplace, rather it will alway be replaced with an actual value first
136- ZeroTangent ()
137- end
138- )
139- Expr (:kw , fname, fval)
140- end
141-
129+
142130 # easy case exit early, can't hold references, can't be a reference.
143131 if isbitstype (primal)
132+ zfield_exprs = map (fieldnames (primal)) do fname
133+ fval = :(zero_tangent (getfield (primal, $ (QuoteNode (fname)))))
134+ Expr (:kw , fname, fval)
135+ end
144136 return :($ Tangent {$primal} ($ (Expr (:parameters , zfield_exprs... ))))
145137 end
146138
147- # hard case need to be prepared for cycic references to this, or that are contained within this
139+ # hard case need to be prepared for references to this, or that are contained within this
148140 quote
149- counts = $ count_references! (primal)
141+ counts = $ count_references (primal)
142+ any_mask = $ (Expr (:tuple , Expr (:parameters , map (fieldnames (primal), fieldtypes (primal)) do fname, ftype
143+ # If it is is unassigned, or if it doesn't have a concrete type, or we have multiple reference to it
144+ # then let it take any value for its tangent
145+ fdef = :(
146+ ! isdefined (primal, $ (QuoteNode (fname))) ||
147+ ! isconcretetype ($ ftype) ||
148+ get (counts, $ (QuoteNode (fname)), 0 ) > 1
149+ )
150+ Expr (:kw , fname, fdef)
151+ end ... )))
152+
153+ # Construct tangents
154+
155+ # Go back and fill in tangents that were not ready
150156 end
151157
152158# # TODO rewrite below
153159 has_mutable_tangent (primal)
154- any_mask = map (fieldnames (primal), fieldtypes (primal)) do fname, ftype
155- # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
156- fdef = :(! isdefined (primal, $ (QuoteNode (fname))) || ! isconcretetype ($ ftype))
157- Expr (:kw , fname, fdef)
158- end
160+ any_mask =
159161 :($ MutableTangent {$primal} (
160162 $ (Expr (:tuple , Expr (:parameters , any_mask... ))),
161163 $ (Expr (:tuple , Expr (:parameters , zfield_exprs... ))),
@@ -184,7 +186,7 @@ function zero_tangent(x::Array{P,N}) where {P,N}
184186end
185187
186188# ##############################################
187- count_references! (x) = count_references (IdDict {Any, Int} (), x)
189+ count_references (x) = count_references (IdDict {Any, Int} (), x)
188190function count_references! (counts:: IdDict{Any, Int} , x)
189191 isbits (x) && return counts # can't be a refernece and can't hold a reference
190192 counts[x] = get (counts, x, 0 ) + 1 # Increment *before* recursing
0 commit comments