108
108
109
109
returns (in order) the correctly escaped:
110
110
- `call` with out any type constraints
111
- - `setup_stmts`: the content of `@setup` or `nothing ` if that is not provided,
111
+ - `setup_stmts`: the content of `@setup` or `[] ` if that is not provided,
112
112
- `inputs`: with all args having the constraints removed from call, or
113
113
defaulting to `Number`
114
114
- `partials`: which are all `Expr{:tuple,...}`
@@ -118,9 +118,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
118
118
# Setup: normalizing input form etc
119
119
120
120
if Meta. isexpr (maybe_setup, :macrocall ) && maybe_setup. args[1 ] == Symbol (" @setup" )
121
- setup_stmts = map ( esc, maybe_setup. args[3 : end ])
121
+ setup_stmts = Any[ esc (ex) for ex in maybe_setup. args[3 : end ]]
122
122
else
123
- setup_stmts = ( nothing ,)
123
+ setup_stmts = []
124
124
partials = (maybe_setup, partials... )
125
125
end
126
126
@assert Meta. isexpr (call, :call )
@@ -185,10 +185,14 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
185
185
# because this is a pull-back there is one per output of function
186
186
Δs = _propagator_inputs (n_outputs)
187
187
188
+ # Make a projector for each argument
189
+ projs, psetup = _make_projectors (call. args[2 : end ])
190
+ append! (setup_stmts, psetup)
191
+
188
192
# 1 partial derivative per input
189
193
pullback_returns = map (1 : n_inputs) do input_i
190
194
∂s = [partial. args[input_i] for partial in partials]
191
- propagation_expr (Δs, ∂s, true )
195
+ propagation_expr (Δs, ∂s, true , projs[input_i] )
192
196
end
193
197
194
198
# Multi-output functions have pullbacks with a tuple input that will be destructured
@@ -215,14 +219,23 @@ end
215
219
" Declares properly hygenic inputs for propagation expressions"
216
220
_propagator_inputs (n) = [esc (gensym (Symbol (:Δ , i))) for i in 1 : n]
217
221
222
+ " given the variable names, escaped but without types, makes setup expressions for projection operators"
223
+ function _make_projectors (xs)
224
+ projs = map (x -> Symbol (:proj_ , x. args[1 ]), xs)
225
+ setups = map ((x,p) -> :($ p = ProjectTo ($ x)), xs, projs)
226
+ return projs, setups
227
+ end
228
+
218
229
"""
219
- propagation_expr(Δs, ∂s, _conj = false)
230
+ propagation_expr(Δs, ∂s, [ _conj= false, proj=identity] )
220
231
221
- Returns the expression for the propagation of
222
- the input gradient `Δs` though the partials `∂s`.
223
- Specify `_conj = true` to conjugate the partials.
232
+ Returns the expression for the propagation of
233
+ the input gradient `Δs` though the partials `∂s`.
234
+ Specify `_conj = true` to conjugate the partials.
235
+ Projector `proj` is a function that will be applied at the end;
236
+ for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity`
224
237
"""
225
- function propagation_expr (Δs, ∂s, _conj = false )
238
+ function propagation_expr (Δs, ∂s, _conj= false , proj = identity )
226
239
# This is basically Δs ⋅ ∂s
227
240
_∂s = map (∂s) do ∂s_i
228
241
if _conj
@@ -249,7 +262,7 @@ function propagation_expr(Δs, ∂s, _conj = false)
249
262
:($ (_∂s[1 ]) * $ (Δs[1 ]))
250
263
end
251
264
252
- return summed_∂_mul_Δs
265
+ return :( $ proj ( $ summed_∂_mul_Δs))
253
266
end
254
267
255
268
"""
0 commit comments