@@ -56,15 +56,22 @@ function split_bc_rule(f::F, args...) where {F}
5656 # Fast path: just broadcast, and use x & y to find derivative.
5757 ys = f .(args... )
5858 _print (" path 2" )
59- function back_2 (dys)
59+ function back_2_one (dys) # For f.(x) we do not need StructArrays / unzip at all
60+ delta = broadcast (unthunk (dys), ys, args... ) do dy, y, a
61+ das = only (derivatives_given_output (y, f, a))
62+ dy * conj (only (das))
63+ end
64+ (NoTangent (), NoTangent (), unbroadcast (only (args), delta))
65+ end
66+ function back_2_many (dys)
6067 deltas = splitcast (unthunk (dys), ys, args... ) do dy, y, as...
6168 das = only (derivatives_given_output (y, f, as... ))
6269 map (da -> dy * conj (da), das)
6370 end
64- dargs = map (unbroadcast, args, deltas)
71+ dargs = map (unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of splitcast?
6572 (NoTangent (), NoTangent (), dargs... )
6673 end
67- return ys, back_2
74+ return ys, length (args) == 1 ? back_2_one : back_2_many
6875 else
6976 # Slow path: collect all the pullbacks & apply them later.
7077 # Since broadcast makes no guarantee about order, this does not bother to try to reverse it.
@@ -88,6 +95,62 @@ splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args
8895# warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
8996splitcast (f, args... ) = StructArrays. components (StructArray (Broadcast. instantiate (Broadcast. broadcasted (f, args... ))))
9097
98+ #=
99+ # This is how you could handle CuArrays: route them to the unzip(map(...)) fallback path.
100+ # The same route may also suit 2nd derivatives: it avoids writing a gradient for splitcast, and the rule for unzip is easy.
101+
102+ function Diffractor.splitmap(f, args...)
103+ if any(a -> a isa CuArray, args)
104+ Diffractor._print("unzip splitmap")
105+ unzip(map(f, args...))
106+ else
107+ StructArrays.components(StructArray(Iterators.map(f, args...)))
108+ end
109+ end
110+ function Diffractor.splitcast(f, args...)
111+ if any(a -> a isa CuArray, args)
112+ Diffractor._print("unzip splitcast")
113+ unzip(broadcast(f, args...))
114+ else
115+ StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
116+ end
117+ end
118+
119+ gradient(x -> sum(log.(x) .+ x'), cu([1,2,3]))[1]
120+ gradient(x -> sum(sqrt.(atan.(x, x'))), cu([1,2,3]))[1]
121+
122+ =#
123+
124+ function unzip (xs:: AbstractArray )
125+ x1 = first (xs)
126+ x1 isa Tuple || throw (ArgumentError (" unzip only accepts arrays of tuples" ))
127+ N = length (x1)
128+ unzip (xs, Val (N)) # like Zygote's unzip
129+ end
130+ @generated function unzip (xs, :: Val{N} ) where {N}
131+ each = [:(map ($ (Get (i)), xs)) for i in 1 : N]
132+ Expr (:tuple , each... )
133+ end
134+ unzip (xs:: AbstractArray{Tuple{T}} ) where {T} = (reinterpret (T, xs),) # best case, no copy
135+ @generated function unzip (xs:: AbstractArray{Ts} ) where {Ts<: Tuple }
136+ each = if count (! Base. issingletontype, Ts. parameters) < 2
137+ # good case, no copy of data, some trivial arrays
138+ [Base. issingletontype (T) ? :(similar (xs, $ T)) : :(reinterpret ($ T, xs)) for T in Ts. parameters]
139+ else
140+ [:(map ($ (Get (i)), xs)) for i in 1 : length (fieldnames (Ts))]
141+ end
142+ Expr (:tuple , each... )
143+ end
144+
145+ struct Get{i} end
146+ Get (i) = Get {Int(i)} ()
147+ (:: Get{i} )(x) where {i} = x[i]
148+
# Reverse rule for `unzip`: the forward pass splits an array of tuples into a
# tuple of arrays; the pullback zips the per-position cotangent arrays back into
# a single array of tuples.
function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray)
    ys = unzip(xs)
    function rezip(dys)
        # `unthunk` forces a possibly-lazy cotangent before it is splatted.
        dxs = broadcast(tuple, unthunk(dys)...)
        return (NoTangent(), dxs)
    end
    return ys, rezip
end
153+
91154# For certain cheap operations we can easily allow fused broadcast:
92155
93156(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), args... ) = split_bc_plus (args... )
0 commit comments