108
108
109
109
returns (in order) the correctly escaped:
110
110
- `call` with out any type constraints
111
- - `setup_stmts`: the content of `@setup` or `nothing ` if that is not provided,
111
+ - `setup_stmts`: the content of `@setup` or `[] ` if that is not provided,
112
112
- `inputs`: with all args having the constraints removed from call, or
113
113
defaulting to `Number`
114
114
- `partials`: which are all `Expr{:tuple,...}`
@@ -118,9 +118,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
118
118
# Setup: normalizing input form etc
119
119
120
120
if Meta. isexpr (maybe_setup, :macrocall ) && maybe_setup. args[1 ] == Symbol (" @setup" )
121
- setup_stmts = map ( esc, maybe_setup. args[3 : end ])
121
+ setup_stmts = Any[ esc (ex) for ex in maybe_setup. args[3 : end ]]
122
122
else
123
- setup_stmts = ( nothing ,)
123
+ setup_stmts = []
124
124
partials = (maybe_setup, partials... )
125
125
end
126
126
@assert Meta. isexpr (call, :call )
@@ -185,10 +185,14 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
185
185
# because this is a pull-back there is one per output of function
186
186
Δs = _propagator_inputs (n_outputs)
187
187
188
+ # Make a projector for each argument
189
+ ps, setup_ps = _make_projectors (call. args[2 : end ])
190
+ append! (setup_stmts, setup_ps)
191
+
188
192
# 1 partial derivative per input
189
193
pullback_returns = map (1 : n_inputs) do input_i
190
194
∂s = [partial. args[input_i] for partial in partials]
191
- propagation_expr (Δs, ∂s, true )
195
+ propagation_expr (Δs, ∂s, true , ps[input_i] )
192
196
end
193
197
194
198
# Multi-output functions have pullbacks with a tuple input that will be destructured
@@ -215,14 +219,22 @@ end
215
219
" Declares properly hygenic inputs for propagation expressions"
216
220
_propagator_inputs (n) = [esc (gensym (Symbol (:Δ , i))) for i in 1 : n]
217
221
222
+ " given the variable names, escaped but without types, makes setup expressions for projection operators"
223
+ function _make_projectors (xs)
224
+ ps = map (x -> gensym (Symbol (:proj_ , x. args[1 ])), xs)
225
+ setups = map ((x,p) -> :($ p = ProjectTo ($ x)), xs, ps)
226
+ return ps, setups
227
+ end
228
+
218
229
"""
219
- propagation_expr(Δs, ∂s, _conj = false)
230
+ propagation_expr(Δs, ∂s, [ _conj = false, proj = identity] )
220
231
221
- Returns the expression for the propagation of
222
- the input gradient `Δs` though the partials `∂s`.
223
- Specify `_conj = true` to conjugate the partials.
232
+ Returns the expression for the propagation of
233
+ the input gradient `Δs` though the partials `∂s`.
234
+ Specify `_conj = true` to conjugate the partials.
235
+ Projector `proj`, usually `ProjectTo(x)`, is applied at the end.
224
236
"""
225
- function propagation_expr (Δs, ∂s, _conj = false )
237
+ function propagation_expr (Δs, ∂s, _conj = false , proj = identity )
226
238
# This is basically Δs ⋅ ∂s
227
239
_∂s = map (∂s) do ∂s_i
228
240
if _conj
@@ -249,7 +261,7 @@ function propagation_expr(Δs, ∂s, _conj = false)
249
261
:($ (_∂s[1 ]) * $ (Δs[1 ]))
250
262
end
251
263
252
- return summed_∂_mul_Δs
264
+ return :( $ proj ( $ summed_∂_mul_Δs))
253
265
end
254
266
255
267
"""
0 commit comments