3434using ChainRulesCore: derivatives_given_output
3535
3636# Broadcast over one element is just map
37- function (∂⃖ₙ: :∂⃖ {N})(:: typeof (broadcasted), f, a:: Array ) where {N}
38- ∂⃖ₙ (map, f, a)
39- end
37+ # function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
38+ # ∂⃖ₙ(map, f, a)
39+ # end
40+
# Rule for materialising a lazy `Broadcasted` with `copy`: the incoming cotangent is
# handed back unchanged as the tangent of `bc`.
function (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted)
    copy_back(Δ) = (NoTangent(), Δ)
    return copy(bc), copy_back
end
4042
41- (:: ∂⃖{1 })(:: typeof (broadcasted), f, args... ) = split_bc_rule (f, args... )
42- (:: ∂⃖{1 })(:: typeof (broadcasted), f, arg:: Array ) = split_bc_rule (f, arg) # ambiguity
# Generic entry point: any broadcast without a more specific rule is handled by
# `split_bc_rule`. The function is bound to the type parameter `F` in the signature.
function (::∂⃖{1})(::typeof(broadcasted), f::F, args...) where {F}
    return split_bc_rule(f, args...)
end
44+ # (::∂⃖{1})(::typeof(broadcasted), f::F , arg::Array) where {F} = split_bc_rule(f, arg) # ambiguity
4345function split_bc_rule (f:: F , args:: Vararg{Any,N} ) where {F,N}
4446 T = Broadcast. combine_eltypes (f, args)
4547 TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (eltype, args)... })
@@ -48,17 +50,17 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
4850 back_1 (_) = ntuple (Returns (ZeroTangent ()), length (args)+ 2 )
4951 return f .(args... ), back_1
5052 elseif T <: Number && isconcretetype (TΔ)
51- # Fast path: just broadcast, and use x & y to find derivative .
53+ # Fast path: just broadcast, and use arguments & result to find derivatives .
5254 ys = f .(args... )
5355 function back_2_one (dys) # For f.(x) we do not need StructArrays / unzip at all
5456 delta = broadcast (unthunk (dys), ys, args... ) do dy, y, a
5557 das = only (derivatives_given_output (y, f, a))
56- dy * conj (only (das))
58+ dy * conj (only (das)) # possibly this * should be made nan-safe.
5759 end
5860 (NoTangent (), NoTangent (), unbroadcast (only (args), delta))
5961 end
6062 function back_2_many (dys)
61- deltas = splitcast (unthunk (dys), ys, args... ) do dy, y, as...
63+ deltas = tuplecast (unthunk (dys), ys, args... ) do dy, y, as...
6264 das = only (derivatives_given_output (y, f, as... ))
6365 map (da -> dy * conj (da), das)
6466 end
@@ -70,62 +72,76 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
7072 # Slow path: collect all the pullbacks & apply them later.
7173 # (Since broadcast makes no guarantee about order of calls, and un-fusing
7274 # can change the number of calls, this does not bother to try to reverse.)
73- ys , backs = splitcast (∂⃖ {1} (), f, args... )
75+ ys3 , backs = tuplecast (∂⃖ {1} (), f, args... )
7476 function back_3 (dys)
75- deltas = splitmap (backs, unthunk (dys)) do back, dy
77+ deltas = tuplecast (backs, unthunk (dys)) do back, dy # could be map, sizes match
7678 map (unthunk, back (dy))
7779 end
78- dargs = map (unbroadcast, args, Base. tail (deltas)) # no real need to close over args here
80+ dargs = map (unbroadcast, args, Base. tail (deltas))
7981 (NoTangent (), sum (first (deltas)), dargs... )
8082 end
8183 back_3 (:: AbstractZero ) = (NoTangent (), map (Returns (ZeroTangent ()), args)... )
82- return ys , back_3
84+ return ys3 , back_3
8385 end
8486end
8587
# Don't run broadcasting on scalars: differentiate the plain call `f(args...)` and
# prepend a `NoTangent()` to whatever its pullback returns.
function split_bc_rule(f::F, args::Number...) where {F}
    value, pullback = ∂⃖{1}()(f, args...)
    scalar_back(dz) = (NoTangent(), pullback(dz)...)
    return value, scalar_back
end
93+
# `identity.(x)` needs no work: return `x` itself and pass the cotangent straight through.
function split_bc_rule(::typeof(identity), x)
    identity_back(Δ) = (NoTangent(), NoTangent(), Δ)
    return x, identity_back
end
# Same rule restated for `x::Number`, so that it is more specific than the all-scalar
# `split_bc_rule(f, args::Number...)` method rather than ambiguous with it.
function split_bc_rule(::typeof(identity), x::Number)
    identity_back(Δ) = (NoTangent(), NoTangent(), Δ)
    return x, identity_back
end
96+
# Skip AD'ing through the axis computation
function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
    instantiated = Base.Broadcast.instantiate(bc)
    # The cotangent of the instantiated object is passed back as the tangent of `bc`.
    uninstantiate(Δ) = Core.tuple(NoTangent(), Δ)
    return instantiated, uninstantiate
end
91102
92- # This uses "multimap"-like constructs:
93103using StructArrays
94- splitmap (f, args... ) = StructArrays. components (StructArray (Iterators. map (f, args... )))
95- splitcast (f, args... ) = StructArrays. components (StructArray (Broadcast. instantiate (Broadcast. broadcasted (f, args... ))))
104+
"""
    tuplecast(f, args...)

Broadcast `f` over `args`, where `f` returns a tuple, and return
`StructArrays.components` of the result, i.e. the tuple's fields as separate
containers instead of one container of tuples.

Throws an `ArgumentError` when the inferred element type is concrete but not a `Tuple`.
"""
function tuplecast(f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    # Only a concrete inferred eltype can be rejected up front; anything else is
    # left for the broadcast itself to sort out.
    if isconcretetype(T) && !(T <: Tuple)
        throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple."))
    end
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    return StructArrays.components(StructArray(bc))
end
96113
97114# For certain cheap operations we can easily allow fused broadcast:
# Argument types accepted by the fused rules below: plain numbers, arrays of numbers,
# tuples of numbers, and lazy `Broadcasted` objects.
const NumericOrBroadcast = Union{
    Number,
    AbstractArray{<:Number},
    NTuple{<:Any,Number},
    Broadcast.Broadcasted
}
98116
99- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), args... ) = lazy_bc_plus (args... )
100- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), arg :: Array ) = lazy_bc_plus (arg) # ambiguity
# `+` is cheap enough to stay fused: numeric arguments go through `lazy_bc_plus`,
# which keeps the forward pass lazy.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::NumericOrBroadcast...) = lazy_bc_plus(args...)
# All-scalar calls skip broadcasting entirely. The signature has to be a vararg
# (`args::Number...`) so that `broadcasted(+, 1, 2)` lands here as well, in line with
# the `x::Number, y::Number` methods defined for `-`, `*` and `/`.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::Number...) = split_bc_rule(+, args...)
"""
    lazy_bc_plus(xs...)

Rule for `broadcasted(+, xs...)`. Returns the lazy `broadcasted(+, xs...)` together
with a pullback that hands the (unthunked) cotangent to every argument, passing it
through `unbroadcast` for each one.
"""
function lazy_bc_plus(xs...)
    function plus_back(Δraw)
        Δ = unthunk(Δraw)
        # Two `NoTangent()`s for `broadcasted` and `+`, then one entry per argument.
        return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...)
    end
    return broadcasted(+, xs...), plus_back
end
106124
107- (:: ∂⃖{1 })(:: typeof (copy), bc:: Broadcast.Broadcasted ) = copy (bc), Δ -> (NoTangent (), Δ)
108-
109- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (- ), x, y)
# Two scalars: no broadcasting needed, use the scalar path of `split_bc_rule`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::Number, y::Number)
    return split_bc_rule(-, x, y)
end

# Fused subtraction: the forward pass stays lazy; the pullback sends `Δ` to `x` and
# `-Δ` to `y`, each passed through `unbroadcast`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
    function minus_back(Δraw)
        Δ = unthunk(Δraw)
        dx = unbroadcast(x, Δ)
        dy = -unbroadcast(y, Δ)
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return broadcasted(-, x, y), minus_back
end
115131
116132using LinearAlgebra: dot
117- const Numeric{T<: Number } = Union{T, AbstractArray{T}}
118133
119- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x:: Numeric , y:: Numeric )
# Two scalars: no broadcasting needed, use the scalar path of `split_bc_rule`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Number, y::Number)
    return split_bc_rule(*, x, y)
end

# Fused multiplication: the forward pass stays lazy; each argument's gradient is
# computed by `_back_star`, which dispatches on the kind of argument.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
    function times_back(Δraw)
        Δ = unthunk(Δraw)
        dx = _back_star(x, y, Δ)
        dy = _back_star(y, x, Δ)
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return broadcasted(*, x, y), times_back
end
# Gradient of `x .* y` with respect to `x`, given the output cotangent `Δ`.
# Generic case: multiply by the conjugated partner and pass the result to `unbroadcast`.
_back_star(x, y, Δ) = unbroadcast(x, Δ .* conj.(y))
# A scalar `x` multiplies every element, so its gradient is a full reduction.
# `dot` conjugates its first argument, so this is `sum(conj.(y) .* Δ)`.
_back_star(x::Number, y, Δ) = dot(y, Δ)
# Booleans get no gradient. Arrays of `Bool` are short-circuited as well, so no
# product or reduction is computed for them.
_back_star(x::Bool, y, Δ) = NoTangent()
_back_star(x::AbstractArray{Bool}, y, Δ) = NoTangent()
127143
128- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x, :: Val{2} )
144+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: NumericOrBroadcast , :: Val{2} )
129145 broadcasted (* , x, x), Δ -> begin
130146 dx = unbroadcast (x, 2 .* unthunk (Δ) .* conj .(x))
131147 (NoTangent (), NoTangent (), NoTangent (), dx, NoTangent ())
@@ -135,41 +151,40 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::type
135151 x^ 2 , Δ -> (NoTangent (), NoTangent (), NoTangent (), 2 * Δ * conj (x), NoTangent ())
136152end
137153
138- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (/ ), x:: Numeric , y:: Number )
139- z, back = ∂⃖ {1} ()(/ , x, y)
140- z, dz -> begin
141- _, dx, dy = back (dz)
# Two scalars: no broadcasting needed, use the scalar path of `split_bc_rule`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Number, y::Number)
    return split_bc_rule(/, x, y)
end

# Division by a scalar. The forward pass is evaluated eagerly (`broadcast`, not
# `broadcasted`) because the pullback reuses the result.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
    quotient = broadcast(/, x, y)
    function divide_back(Δth)
        Δ = unthunk(Δth)
        dx = unbroadcast(x, Δ ./ conj.(y))
        dy = -dot(quotient, Δ) / (conj(y))  # the reason to be eager is to allow dot here
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return quotient, divide_back
end
145163
146- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (identity), x) = x, identity_pullback
147- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (identity), x:: Array ) = x, identity_pullback # ambiguity
148- identity_pullback (Δ) = (NoTangent (), NoTangent (), Δ)
# `identity.(x)`: delegate to the dedicated `split_bc_rule(identity, x)` method.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x)
    return split_bc_rule(identity, x)
end
165+ # (::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = split_bc_rule(identity, x) # ambiguity
149166
150- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (conj), x:: AbstractArray{Real} ) = x, identity_pullback
151- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (conj), x:: Array{Real} ) = x, identity_pullback
# `conj` does nothing to real arrays, so reuse the identity rule. The `<:Real` bound
# matters: type parameters are invariant, so `AbstractArray{Real}` would only match
# arrays whose eltype is exactly the abstract type `Real`, never e.g. `Vector{Float64}`.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real}) = split_bc_rule(identity, x)
# Needed so that real `Array`s are not ambiguous between the method above and the
# `x::Array` method below.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{<:Real}) = split_bc_rule(identity, x)

# General case: stay lazy on the way forward and conjugate the cotangent on the way
# back. The pullback returns one tangent per input of `broadcasted(conj, x)`:
# `broadcasted`, `conj`, and `x`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x)
    conj_back(Δ) = (NoTangent(), NoTangent(), conj(unthunk(Δ)))
    return broadcasted(conj, x), conj_back
end
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array)
    conj_back(Δ) = (NoTangent(), NoTangent(), conj(unthunk(Δ)))
    return broadcasted(conj, x), conj_back
end
156173
157- # All broadcasts use `unbroadcast` to reduce to correct shape:
158-
# Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape:
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
    N = ndims(dx)
    if length(x) != length(dx)
        # hack to get type-stable `dims`: always an N-tuple, holding `d` where `x` has
        # size 1 (or no such dimension) and the placeholder `N+1` everywhere else.
        dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N + 1, N)
        return ProjectTo(x)(sum(dx; dims))  # ideally this sum might be thunked?
    end
    # Same number of elements: nothing to sum.
    return ProjectTo(x)(dx)  # handles trivial reshapes, offsets, structured matrices, row vectors
end

# Zero cotangents are passed along untouched.
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero)
    return dx
end
169185
# One-element tuple: the whole cotangent is summed, wrapped in a `Tangent{T}`, and
# then projected with `ProjectTo(x)`.
function unbroadcast(x::T, dx) where {T<:Tuple{Any}}
    total = sum(dx)
    return ProjectTo(x)(Tangent{T}(total))
end
171187function unbroadcast (x:: T , dx) where {T<: Tuple{Vararg{Any,N}} } where {N}
172- _print (" unbroadcast tuple" )
173188 val = if length (x) == length (dx)
174189 dx
175190 else
0 commit comments