4444
# Generic rule for `broadcasted(f, args...)`: hands the broadcast function and all of
# its arguments to `split_bc_rule`, which returns the primal result and a pullback.
function (::∂⃖{1})(::typeof(broadcasted), f, args...)
    return split_bc_rule(f, args...)
end

# Same rule for a single `Array` argument; this method exists to resolve a dispatch
# ambiguity (per the original comment), and forwards to `split_bc_rule` as well.
function (::∂⃖{1})(::typeof(broadcasted), f, arg::Array)
    return split_bc_rule(f, arg)
end
47- function split_bc_rule (f:: F , args... ) where {F}
47+ function split_bc_rule (f:: F , args:: Vararg{Any,N} ) where {F,N }
4848 T = Broadcast. combine_eltypes (f, args)
4949 TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (eltype, args)... })
5050 if eltype (T) == Bool
@@ -71,10 +71,11 @@ function split_bc_rule(f::F, args...) where {F}
7171 dargs = map (unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of splitcast?
7272 (NoTangent (), NoTangent (), dargs... )
7373 end
74- return ys, length (args) == 1 ? back_2_one : back_2_many
74+ return ys, N == 1 ? back_2_one : back_2_many
7575 else
7676 # Slow path: collect all the pullbacks & apply them later.
77- # Since broadcast makes no guarantee about order, this does not bother to try to reverse it.
77+ # (Since broadcast makes no guarantee about order of calls, and un-fusing
78+ # can change the number of calls, this does not bother to try to reverse.)
7879 _print (" path 3" )
7980 ys, backs = splitcast (∂⃖ {1} (), f, args... )
8081 function back_3 (dys)
@@ -84,15 +85,21 @@ function split_bc_rule(f::F, args...) where {F}
8485 dargs = map (unbroadcast, args, Base. tail (deltas)) # no real need to close over args here
8586 (NoTangent (), sum (first (deltas)), dargs... )
8687 end
88+ back_3 (:: AbstractZero ) = (NoTangent (), map (Returns (ZeroTangent ()), args)... )
8789 return ys, back_3
8890 end
8991end
9092
91- # This uses "mulltimap"-like constructs:
# Rule for `Broadcast.instantiate`: skip AD'ing through the axis computation.
# The primal is computed by calling `instantiate` directly, and the pullback returns
# `NoTangent()` for `instantiate` itself while passing Δ through unchanged for `bc`.
function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
    instantiated = Base.Broadcast.instantiate(bc)
    function uninstantiate(Δ)
        return Core.tuple(NoTangent(), Δ)
    end
    return instantiated, uninstantiate
end
98+
99+ # This uses "multimap"-like constructs:
92100
93101using StructArrays
# Lazily maps `f` over `args` with `Iterators.map`, collects the results into a
# `StructArray`, and returns that array's components.
function splitmap(f, args...)
    lazy = Iterators.map(f, args...)
    return StructArrays.components(StructArray(lazy))
end
95- # warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
# Broadcast counterpart of `splitmap`: builds the lazy `broadcasted(f, args...)`,
# instantiates it, collects it into a `StructArray`, and returns the components.
function splitcast(f, args...)
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    return StructArrays.components(StructArray(bc))
end
97104
98105#=
# Specialised rule for `broadcasted(+, args...)`: forwards every argument to
# `split_bc_plus` rather than going through the generic broadcast rule.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...)
    return split_bc_plus(args...)
end

# Single-`Array` method of the `+` rule; it exists to resolve a dispatch ambiguity
# (per the original comment) and forwards to `split_bc_plus` too.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array)
    return split_bc_plus(arg)
end
# Rule for broadcasting `+` over `xs...`.
#
# Returns a 2-tuple:
#   1. the lazy `broadcasted(+, xs...)` object (the primal result), and
#   2. a pullback which, given the output tangent, returns `NoTangent()` for
#      `broadcasted`, `NoTangent()` for `+`, and then one `unbroadcast(x, Δ)` per
#      element of `xs`.
#
# The previous signature carried a `where {F}` clause whose parameter was never used;
# such an unbound static parameter makes Julia warn at definition time, so it is
# dropped here. The set of accepted arguments is unchanged.
function split_bc_plus(xs...)
    function back_plus(Δraw)
        # Unthunk the incoming tangent once so the same value is shared by all `xs`.
        Δ = unthunk(Δraw)
        _print(" broadcast +")
        return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...)
    end
    return broadcasted(+, xs...), back_plus
end
164171Base. eltype (bc:: Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple} ) =
@@ -167,20 +174,20 @@ Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) =
# Rule for `copy` applied to a lazy `Broadcasted`: the primal is `copy(bc)`, and the
# pullback returns `NoTangent()` for `copy` itself and passes Δ through as-is for `bc`.
function (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted)
    primal = copy(bc)
    uncopy(Δ) = (NoTangent(), Δ)
    return primal, uncopy
end
168175
# Rule for `broadcasted(-, x, y)`.
#
# Returns the lazy `broadcasted(-, x, y)` together with a pullback that yields
# `NoTangent()` for `broadcasted`, `NoTangent()` for `-`, `unbroadcast(x, Δ)` for `x`,
# and the negated `unbroadcast(y, Δ)` for `y`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
    function back_minus(Δraw)
        Δ = unthunk(Δraw)
        _print(" broadcast -")
        dx = unbroadcast(x, Δ)
        # Ideally the negation could be fused into unbroadcast (mapreduce() rather
        # than sum) when y is a smaller array.
        dy = -unbroadcast(y, Δ)
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return broadcasted(-, x, y), back_minus
end
176183
177184using LinearAlgebra: dot
178185
179186function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x, y) # should this be vararg, or will laziness handle it?
180- broadcasted (* , x, y), Δ -> let Δun = unthunk (Δ )
187+ broadcasted (* , x, y), Δraw -> let Δ = unthunk (Δraw )
181188 _print (" broadcast *" )
182- dx = eltype (x)== Bool ? NoTangent () : x isa Number ? dot (y, Δun ) : unbroadcast (x, Δun .* conj .(y))
183- dy = eltype (y)== Bool ? NoTangent () : y isa Number ? dot (x, Δun ) : unbroadcast (y, Δun .* conj .(x))
189+ dx = eltype (x)== Bool ? NoTangent () : x isa Number ? dot (y, Δ ) : unbroadcast (x, Δ .* conj .(y))
190+ dy = eltype (y)== Bool ? NoTangent () : y isa Number ? dot (x, Δ ) : unbroadcast (y, Δ .* conj .(x))
184191 # When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
185192 # Will things like this work? Ref([1,2]) .* [1,2,3]
186193 (NoTangent (), NoTangent (), dx, dy)
0 commit comments