@@ -74,14 +74,8 @@ macro scalar_rule(call, maybe_setup, partials...)
7474 )
7575 f = call. args[1 ]
7676
77- # An expression that when evaluated will return the type of the input domain.
78- # Multiple repetitions of this expression should optimize out. But if it does not then
79- # may need to move its definition into the body of the `rrule`/`frule`
80- 𝒟 = :(typeof (first (promote ($ (call. args[2 : end ]. .. )))))
81-
82- frule_expr = scalar_frule_expr (𝒟, f, call, setup_stmts, inputs, partials)
83- rrule_expr = scalar_rrule_expr (𝒟, f, call, setup_stmts, inputs, partials)
84-
77+ frule_expr = scalar_frule_expr (f, call, setup_stmts, inputs, partials)
78+ rrule_expr = scalar_rrule_expr (f, call, setup_stmts, inputs, partials)
8579
8680 # ###########################################################################
8781 # Final return: building the expression to insert in the place of this macro
@@ -147,7 +141,7 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
147141 return call, setup_stmts, inputs, partials
148142end
149143
150- function scalar_frule_expr (𝒟, f, call, setup_stmts, inputs, partials)
144+ function scalar_frule_expr (f, call, setup_stmts, inputs, partials)
151145 n_outputs = length (partials)
152146 n_inputs = length (inputs)
153147
@@ -156,7 +150,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
156150 Δs = [Symbol (string (:Δ , i)) for i in 1 : n_inputs]
157151 pushforward_returns = map (1 : n_outputs) do output_i
158152 ∂s = partials[output_i]. args
159- propagation_expr (𝒟, Δs, ∂s)
153+ propagation_expr (Δs, ∂s)
160154 end
161155 if n_outputs > 1
162156 # For forward-mode we only return a tuple if output actually a tuple.
@@ -182,7 +176,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
182176 end
183177end
184178
185- function scalar_rrule_expr (𝒟, f, call, setup_stmts, inputs, partials)
179+ function scalar_rrule_expr (f, call, setup_stmts, inputs, partials)
186180 n_outputs = length (partials)
187181 n_inputs = length (inputs)
188182
@@ -193,7 +187,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
193187 # 1 partial derivative per input
194188 pullback_returns = map (1 : n_inputs) do input_i
195189 ∂s = [partial. args[input_i] for partial in partials]
196- propagation_expr (𝒟, Δs, ∂s)
190+ propagation_expr (Δs, ∂s)
197191 end
198192
199193 pullback = quote
@@ -212,30 +206,14 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
212206end
213207
214208"""
215- propagation_expr(𝒟, Δs, ∂s)
209+ propagation_expr(Δs, ∂s)
216210
217211 Returns the expression for the propagation of
218212 the input gradient `Δs` though the partials `∂s`.
219-
220- 𝒟 is an expression that when evaluated returns the type-of the input domain.
221- For example if the derivative is being taken at the point `1` it returns `Int`.
222- if it is taken at `1+1im` it returns `Complex{Int}`.
223- At present it is ignored for non-Wirtinger derivatives.
224213"""
225- function propagation_expr (𝒟, Δs, ∂s)
226- wirtinger_indices = findall (∂s) do ex
227- Meta. isexpr (ex, :call ) && ex. args[1 ] === :Wirtinger
228- end
229- ∂s = map (esc, ∂s)
230- if isempty (wirtinger_indices)
231- return standard_propagation_expr (Δs, ∂s)
232- else
233- return wirtinger_propagation_expr (𝒟, wirtinger_indices, Δs, ∂s)
234- end
235- end
236-
237- function standard_propagation_expr (Δs, ∂s)
214+ function propagation_expr (Δs, ∂s)
238215 # This is basically Δs ⋅ ∂s
216+ ∂s = map (esc, ∂s)
239217
240218 # Notice: the thunking of `∂s[i] (potentially) saves us some computation
241219 # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
@@ -244,36 +222,6 @@ function standard_propagation_expr(Δs, ∂s)
244222 return :(+ ($ (∂_mul_Δs... )))
245223end
246224
247- function wirtinger_propagation_expr (𝒟, wirtinger_indices, Δs, ∂s)
248- ∂_mul_Δs_primal = Any[]
249- ∂_mul_Δs_conjugate = Any[]
250- ∂_wirtinger_defs = Any[]
251- for i in 1 : length (∂s)
252- if i in wirtinger_indices
253- Δi = Δs[i]
254- ∂i = Symbol (string (:∂ , i))
255- push! (∂_wirtinger_defs, :($ ∂i = $ (∂s[i])))
256- ∂f∂i_mul_Δ = :(wirtinger_primal ($ ∂i) * wirtinger_primal ($ Δi))
257- ∂f∂ī_mul_Δ̄ = :(conj (wirtinger_conjugate ($ ∂i)) * wirtinger_conjugate ($ Δi))
258- ∂f̄∂i_mul_Δ = :(wirtinger_conjugate ($ ∂i) * wirtinger_primal ($ Δi))
259- ∂f̄∂ī_mul_Δ̄ = :(conj (wirtinger_primal ($ ∂i)) * wirtinger_conjugate ($ Δi))
260- push! (∂_mul_Δs_primal, :($ ∂f∂i_mul_Δ + $ ∂f∂ī_mul_Δ̄))
261- push! (∂_mul_Δs_conjugate, :($ ∂f̄∂i_mul_Δ + $ ∂f̄∂ī_mul_Δ̄))
262- else
263- ∂_mul_Δ = :(@thunk ($ (∂s[i])) * $ (Δs[i]))
264- push! (∂_mul_Δs_primal, ∂_mul_Δ)
265- push! (∂_mul_Δs_conjugate, ∂_mul_Δ)
266- end
267- end
268- primal_sum = :(+ ($ (∂_mul_Δs_primal... )))
269- conjugate_sum = :(+ ($ (∂_mul_Δs_conjugate... )))
270- return quote # This will be a block, so will have value equal to last statement
271- $ (∂_wirtinger_defs... )
272- w = Wirtinger ($ primal_sum, $ conjugate_sum)
273- refine_differential ($ 𝒟, w)
274- end
275- end
276-
277225"""
278226 propagator_name(f, propname)
279227
0 commit comments