3333
3434using ChainRulesCore: derivatives_given_output
3535
36- _print (s) = nothing
37- # _print(s) = printstyled(s, "\n"; color=:magenta)
38-
3936# Broadcast over one element is just map
4037function (∂⃖ₙ: :∂⃖ {N})(:: typeof (broadcasted), f, a:: Array ) where {N}
41- _print (" path 0, order $N " )
4238 ∂⃖ₙ (map, f, a)
4339end
4440
@@ -49,13 +45,11 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
4945 TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (eltype, args)... })
5046 if T === Bool
5147 # Trivial case: non-differentiable output, e.g. `x .> 0`
52- _print (" path 1" )
5348 back_1 (_) = ntuple (Returns (ZeroTangent ()), length (args)+ 2 )
5449 return f .(args... ), back_1
5550 elseif T <: Number && isconcretetype (TΔ)
5651 # Fast path: just broadcast, and use x & y to find derivative.
5752 ys = f .(args... )
58- _print (" path 2" )
5953 function back_2_one (dys) # For f.(x) we do not need StructArrays / unzip at all
6054 delta = broadcast (unthunk (dys), ys, args... ) do dy, y, a
6155 das = only (derivatives_given_output (y, f, a))
@@ -76,7 +70,6 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
7670 # Slow path: collect all the pullbacks & apply them later.
7771 # (Since broadcast makes no guarantee about order of calls, and un-fusing
7872 # can change the number of calls, this does not bother to try to reverse.)
79- _print (" path 3" )
8073 ys, backs = splitcast (∂⃖ {1} (), f, args... )
8174 function back_3 (dys)
8275 deltas = splitmap (backs, unthunk (dys)) do back, dy
@@ -97,74 +90,16 @@ function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.
9790end
9891
9992# This uses "multimap"-like constructs:
100-
10193using StructArrays
10294splitmap (f, args... ) = StructArrays. components (StructArray (Iterators. map (f, args... )))
10395splitcast (f, args... ) = StructArrays. components (StructArray (Broadcast. instantiate (Broadcast. broadcasted (f, args... ))))
10496
105- #=
106- # This is how you could handle CuArrays, route them to unzip(map(...)) fallback path.
107- # Maybe 2nd derivatives too, to avoid writing a gradient for splitcast, rule for unzip is easy.
108-
109- function Diffractor.splitmap(f, args...)
110- if any(a -> a isa CuArray, args)
111- Diffractor._print("unzip splitmap")
112- unzip(map(f, args...))
113- else
114- StructArrays.components(StructArray(Iterators.map(f, args...)))
115- end
116- end
117- function Diffractor.splitcast(f, args...)
118- if any(a -> a isa CuArray, args)
119- Diffractor._print("unzip splitcast")
120- unzip(broadcast(f, args...))
121- else
122- StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
123- end
124- end
125-
126- gradient(x -> sum(log.(x) .+ x'), cu([1,2,3]))[1]
127- gradient(x -> sum(sqrt.(atan.(x, x'))), cu([1,2,3]))[1]
128-
129- =#
130-
131- function unzip (xs:: AbstractArray )
132- x1 = first (xs)
133- x1 isa Tuple || throw (ArgumentError (" unzip only accepts arrays of tuples" ))
134- N = length (x1)
135- unzip (xs, Val (N)) # like Zygote's unzip
136- end
137- @generated function unzip (xs, :: Val{N} ) where {N}
138- each = [:(map ($ (Get (i)), xs)) for i in 1 : N]
139- Expr (:tuple , each... )
140- end
141- unzip (xs:: AbstractArray{Tuple{T}} ) where {T} = (reinterpret (T, xs),) # best case, no copy
142- @generated function unzip (xs:: AbstractArray{Ts} ) where {Ts<: Tuple }
143- each = if count (! Base. issingletontype, Ts. parameters) < 2
144- # good case, no copy of data, some trivial arrays
145- [Base. issingletontype (T) ? :(similar (xs, $ T)) : :(reinterpret ($ T, xs)) for T in Ts. parameters]
146- else
147- [:(map ($ (Get (i)), xs)) for i in 1 : length (fieldnames (Ts))]
148- end
149- Expr (:tuple , each... )
150- end
151-
152- struct Get{i} end
153- Get (i) = Get {Int(i)} ()
154- (:: Get{i} )(x) where {i} = x[i]
155-
156- function ChainRulesCore. rrule (:: typeof (unzip), xs:: AbstractArray )
157- rezip (dy) = (NoTangent (), tuple .(unthunk (dy)... ))
158- return unzip (xs), rezip
159- end
160-
16197# For certain cheap operations we can easily allow fused broadcast:
16298
16399(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), args... ) = lazy_bc_plus (args... )
164100(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), arg:: Array ) = lazy_bc_plus (arg) # ambiguity
165101function lazy_bc_plus (xs... ) where {F}
166102 broadcasted (+ , xs... ), Δraw -> let Δ = unthunk (Δraw)
167- _print (" broadcast +" )
168103 (NoTangent (), NoTangent (), map (x -> unbroadcast (x, Δ), xs)... )
169104 end
170105end
173108
174109function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (- ), x, y)
175110 broadcasted (- , x, y), Δraw -> let Δ = unthunk (Δraw)
176- _print (" broadcast -" )
177111 (NoTangent (), NoTangent (), unbroadcast (x, Δ), - unbroadcast (y, Δ))
178112 # Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array
179113 end
@@ -184,7 +118,6 @@ const Numeric{T<:Number} = Union{T, AbstractArray{T}}
184118
185119function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x:: Numeric , y:: Numeric )
186120 broadcasted (* , x, y), Δraw -> let Δ = unthunk (Δraw)
187- _print (" broadcast *" )
188121 dx = eltype (x)== Bool ? NoTangent () : x isa Number ? dot (y, Δ) : unbroadcast (x, Δ .* conj .(y))
189122 dy = eltype (y)== Bool ? NoTangent () : y isa Number ? dot (x, Δ) : unbroadcast (y, Δ .* conj .(x))
190123 # When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
@@ -193,19 +126,16 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeri
193126end
194127
195128function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x, :: Val{2} )
196- _print (" broadcast ^2" )
197129 broadcasted (* , x, x), Δ -> begin
198130 dx = unbroadcast (x, 2 .* unthunk (Δ) .* conj .(x))
199131 (NoTangent (), NoTangent (), NoTangent (), dx, NoTangent ())
200132 end
201133end
202134function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Number , :: Val{2} )
203- _print (" simple ^2" )
204135 x^ 2 , Δ -> (NoTangent (), NoTangent (), NoTangent (), 2 * Δ * conj (x), NoTangent ())
205136end
206137
207138function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (/ ), x:: Numeric , y:: Number )
208- _print (" simple /" )
209139 z, back = ∂⃖ {1} ()(/ , x, y)
210140 z, dz -> begin
211141 _, dx, dy = back (dz)
0 commit comments