@@ -156,7 +156,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
156156 Δs = [Symbol (string (:Δ , i)) for i in 1 : n_inputs]
157157 pushforward_returns = map (1 : n_outputs) do output_i
158158 ∂s = partials[output_i]. args
159- propagation_expr (𝒟, Δs, ∂s)
159+ frule_propagation_expr (𝒟, Δs, ∂s)
160160 end
161161 if n_outputs > 1
162162 # For forward-mode we only return a tuple if output actually a tuple.
@@ -193,7 +193,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
193193 # 1 partial derivative per input
194194 pullback_returns = map (1 : n_inputs) do input_i
195195 ∂s = [partial. args[input_i] for partial in partials]
196- propagation_expr (𝒟, Δs, ∂s)
196+ rrule_propagation_expr (𝒟, Δs, ∂s)
197197 end
198198
199199 pullback = quote
@@ -222,56 +222,16 @@ end
222222 if it is taken at `1+1im` it returns `Complex{Int}`.
223223 At present it is ignored for non-Wirtinger derivatives.
224224"""
225- function propagation_expr (𝒟, Δs, ∂s)
226- wirtinger_indices = findall (∂s) do ex
227- Meta. isexpr (ex, :call ) && ex. args[1 ] === :Wirtinger
228- end
225+ function frule_propagation_expr (𝒟, Δs, ∂s)
229226 ∂s = map (esc, ∂s)
230- if isempty (wirtinger_indices)
231- return standard_propagation_expr (Δs, ∂s)
232- else
233- return wirtinger_propagation_expr (𝒟, wirtinger_indices, Δs, ∂s)
234- end
235- end
236-
237- function standard_propagation_expr (Δs, ∂s)
238- # This is basically Δs ⋅ ∂s
239-
240- # Notice: the thunking of `∂s[i] (potentially) saves us some computation
241- # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
242- # as the pullback is evaluated
243- ∂_mul_Δs = [:(@thunk ($ (∂s[i])) * $ (Δs[i])) for i in 1 : length (∂s)]
244- return :(+ ($ (∂_mul_Δs... )))
227+ ∂_mul_Δs = [:(chain (@thunk ($ (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
228+ return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
245229end
246230
247- function wirtinger_propagation_expr (𝒟, wirtinger_indices, Δs, ∂s)
248- ∂_mul_Δs_primal = Any[]
249- ∂_mul_Δs_conjugate = Any[]
250- ∂_wirtinger_defs = Any[]
251- for i in 1 : length (∂s)
252- if i in wirtinger_indices
253- Δi = Δs[i]
254- ∂i = Symbol (string (:∂ , i))
255- push! (∂_wirtinger_defs, :($ ∂i = $ (∂s[i])))
256- ∂f∂i_mul_Δ = :(wirtinger_primal ($ ∂i) * wirtinger_primal ($ Δi))
257- ∂f∂ī_mul_Δ̄ = :(conj (wirtinger_conjugate ($ ∂i)) * wirtinger_conjugate ($ Δi))
258- ∂f̄∂i_mul_Δ = :(wirtinger_conjugate ($ ∂i) * wirtinger_primal ($ Δi))
259- ∂f̄∂ī_mul_Δ̄ = :(conj (wirtinger_primal ($ ∂i)) * wirtinger_conjugate ($ Δi))
260- push! (∂_mul_Δs_primal, :($ ∂f∂i_mul_Δ + $ ∂f∂ī_mul_Δ̄))
261- push! (∂_mul_Δs_conjugate, :($ ∂f̄∂i_mul_Δ + $ ∂f̄∂ī_mul_Δ̄))
262- else
263- ∂_mul_Δ = :(@thunk ($ (∂s[i])) * $ (Δs[i]))
264- push! (∂_mul_Δs_primal, ∂_mul_Δ)
265- push! (∂_mul_Δs_conjugate, ∂_mul_Δ)
266- end
267- end
268- primal_sum = :(+ ($ (∂_mul_Δs_primal... )))
269- conjugate_sum = :(+ ($ (∂_mul_Δs_conjugate... )))
270- return quote # This will be a block, so will have value equal to last statement
271- $ (∂_wirtinger_defs... )
272- w = Wirtinger ($ primal_sum, $ conjugate_sum)
273- refine_differential ($ 𝒟, w)
274- end
231+ function rrule_propagation_expr (𝒟, Δs, ∂s)
232+ ∂s = map (esc, ∂s)
233+ ∂_mul_Δs = [:(chain ($ (Δs[i]), @thunk ($ (∂s[i])))) for i in 1 : length (∂s)]
234+ return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
275235end
276236
277237"""
0 commit comments