@@ -29,28 +29,30 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
2929 return r
3030end
3131
32- _print (s) = nothing
33- # _print(s) = printstyled(s, "\n"; color=:magenta)
32+ # Reverse mode broadcast rules
33+
34+ using ChainRulesCore: derivatives_given_output
35+
36+ # _print(s) = nothing
37+ _print (s) = printstyled (s, " \n " ; color= :magenta )
3438
3539# Broadcast over one element is just map
3640function (∂⃖ₙ: :∂⃖ {N})(:: typeof (broadcasted), f, a:: Array ) where {N}
3741 _print (" path 0" )
3842 ∂⃖ₙ (map, f, a)
3943end
4044
41- using ChainRulesCore: derivatives_given_output
42-
4345(:: ∂⃖{1 })(:: typeof (broadcasted), f, args... ) = split_bc_rule (f, args... )
4446(:: ∂⃖{1 })(:: typeof (broadcasted), f, arg:: Array ) = split_bc_rule (f, arg) # ambiguity
4547function split_bc_rule (f:: F , args... ) where {F}
4648 T = Broadcast. combine_eltypes (f, args)
47- if T == Bool
48- # Trivial case
49+ TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (eltype, args)... })
50+ if eltype (T) == Bool
51+ # Trivial case: non-differentiable output
4952 _print (" path 1" )
5053 back_1 (_) = ntuple (Returns (ZeroTangent ()), length (args)+ 2 )
5154 return f .(args... ), back_1
52- elseif isconcretetype (Core. Compiler. _return_type (
53- derivatives_given_output, Tuple{T, F, map (eltype, args)... }))
55+ elseif T <: Number && isconcretetype (TΔ)
5456 # Fast path: just broadcast, and use x & y to find derivative.
5557 ys = f .(args... )
5658 _print (" path 2" )
@@ -65,8 +67,9 @@ function split_bc_rule(f::F, args...) where {F}
6567 return ys, back_2
6668 else
6769 # Slow path: collect all the pullbacks & apply them later.
70+ # Since broadcast makes no guarantee about order, this does not bother to try to reverse it.
6871 _print (" path 3" )
69- ys, backs = splitcast (rrule_via_ad, DiffractorRuleConfig (), f, args... )
72+ ys, backs = splitcast (∂⃖ {1} (), f, args... )
7073 function back_3 (dys)
7174 deltas = splitmap (backs, unthunk (dys)) do back, dy
7275 map (unthunk, back (dy))
@@ -78,8 +81,11 @@ function split_bc_rule(f::F, args...) where {F}
7881 end
7982end
8083
84+ # This uses "mulltimap"-like constructs:
85+
8186using StructArrays
82- splitmap (f, args... ) = StructArrays. components (StructArray (Iterators. map (f, args... ))) # warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
87+ splitmap (f, args... ) = StructArrays. components (StructArray (Iterators. map (f, args... )))
88+ # warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
8389splitcast (f, args... ) = StructArrays. components (StructArray (Broadcast. instantiate (Broadcast. broadcasted (f, args... ))))
8490
8591# For certain cheap operations we can easily allow fused broadcast:
107113
108114using LinearAlgebra: dot
109115
110- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x, y)
116+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x, y) # should this be vararg, or will laziness handle it?
111117 broadcasted (* , x, y), Δ -> let Δun = unthunk (Δ)
112118 _print (" broadcast *" )
113119 dx = eltype (x)== Bool ? NoTangent () : x isa Number ? dot (y, Δun) : unbroadcast (x, Δun .* conj .(y))
@@ -117,41 +123,88 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y)
117123 (NoTangent (), NoTangent (), dx, dy)
118124 end
119125end
126+ # Alternative to `x isa Number` etc above... but not quite right!
127+ # (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y::Number) = rrule_via_ad(DiffractorRuleConfig(), *, x, y)
128+
129+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x, :: Val{2} )
130+ _print (" broadcast ^2" )
131+ broadcasted (* , x, x), Δ -> begin
132+ dx = unbroadcast (x, 2 .* Δ .* conj .(x))
133+ (NoTangent (), NoTangent (), NoTangent (), dx, NoTangent ())
134+ end
135+ end
136+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Number , :: Val{2} )
137+ _print (" simple ^2" )
138+ x^ 2 , Δ -> (NoTangent (), NoTangent (), NoTangent (), 2 * Δ * conj (x), NoTangent ())
139+ end
140+
141+ # function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y) # not obvious whether this is better than automatic
142+ # broadcasted(/, x, y), Δ -> let Δun = unthunk(Δ)
143+ # _print("broadcast /")
144+ # dx = unbroadcast(x, Δ ./ conj.(y))
145+ # dy = unbroadcast(y, .-Δ .* conj.(res ./ y))
146+ # (NoTangent(), NoTangent(), dx, dy)
147+ # end
148+ # end
149+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (/ ), x, y:: Number )
150+ _print (" simple /" )
151+ z, back = ∂⃖ {1} ()(/ , x, y)
152+ z, Δ -> begin
153+ _, dx, dy = back (Δ)
154+ (NoTangent (), NoTangent (), dx, dy) # maybe there should be a funciton for this? Use for conj, identity too
155+ end
156+ end
120157
121158(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (conj), x) =
122159 broadcasted (conj, x), Δ -> (NoTangent (), conj (unthunk (Δ)))
123160(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (conj), x:: AbstractArray{Real} ) =
124161 x, Δ -> (NoTangent (), Δ)
125162
163+ (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (identity), x) =
164+ x, Δ -> (NoTangent (), Δ)
165+
166+ # All broadcasts use `unbroadcast` to reduce to correct shape:
167+
126168function unbroadcast (x:: Base.AbstractArrayOrBroadcasted , dx)
127169 N = ndims (dx)
128170 if length (x) == length (dx)
129171 ProjectTo (x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
130172 else
131- # This is an awful hack to get type-stable `dims`
132- dims = ntuple (d -> get (size (x), d, 1 ) == 1 ? d : N+ 1 , N)
173+ dims = ntuple (d -> get (size (x), d, 1 ) == 1 ? d : N+ 1 , N) # awful hack to get type-stable `dims`
133174 ProjectTo (x)(sum (dx; dims))
134175 end
135176end
136177unbroadcast (x:: Base.AbstractArrayOrBroadcasted , dx:: NoTangent ) = NoTangent ()
137178
179+ unbroadcast (x:: T , dx) where {T<: Tuple{Any} } = ProjectTo (x)(Tangent {T} (sum (dx)))
180+ function unbroadcast (x:: T , dx) where {T<: Tuple{Vararg{Any,N}} } where {N}
181+ _print (" unbroadcast tuple" )
182+ val = if length (x) == length (dx)
183+ dx
184+ else
185+ sum (dx; dims= 2 : ndims (dx))
186+ end
187+ ProjectTo (x)(NTuple {length(x)} (val)) # Tangent
188+ end
189+
190+ unbroadcast (f:: Function , df) = sum (df)
138191unbroadcast (x:: Number , dx) = ProjectTo (x)(sum (dx))
139- unbroadcast (f:: Function , df) = ProjectTo (x)(sum (df))
140192unbroadcast (x:: Base.RefValue , dx) = ProjectTo (x)(Ref (sum (dx)))
141193
142194unbroadcast (:: Bool , dx) = NoTangent ()
143195unbroadcast (:: AbstractArray{Bool} , dx) = NoTangent ()
144196unbroadcast (:: AbstractArray{Bool} , :: NoTangent ) = NoTangent () # ambiguity
145197unbroadcast (:: Val , dx) = NoTangent ()
146- # Maybe more non-diff types? Some fallback?
147198
148- unbroadcast (x:: T , dx) where {T<: Tuple{Any} } = ProjectTo (x)(Tangent {T} (sum (dx)))
149- function unbroadcast (x:: T , dx) where {T<: Tuple{Vararg{Any,N}} } where {N}
150- _print (" unbroadcast tuple" )
151- val = if length (x) == length (dx)
152- dx
199+ function unbroadcast (x, dx)
200+ p = ProjectTo (x)
201+ if dx isa AbstractZero || p isa ProjectTo{<: AbstractZero }
202+ return NoTangent ()
203+ end
204+ b = Broadcast. broadcastable (x)
205+ if b isa Ref # then x is scalar under broadcast
206+ return p (sum (dx))
153207 else
154- sum (dx; dims = 2 : ndims (dx) )
208+ error ( " don't know how to handle broadcast gradient for x:: $( typeof (x)) " )
155209 end
156- ProjectTo (x)(NTuple {length(x)} (val)) # Tangent
157210end
0 commit comments