Skip to content

Commit 5fc9e5d

Browse files
mcabbottoxinabox
authored andcommitted
always include projector in at_scalar_rule
1 parent 6debcc3 commit 5fc9e5d

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

src/rule_definition_tools.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ end
108108
109109
returns (in order) the correctly escaped:
110110
- `call` with out any type constraints
111-
- `setup_stmts`: the content of `@setup` or `nothing` if that is not provided,
111+
- `setup_stmts`: the content of `@setup` or `[]` if that is not provided,
112112
- `inputs`: with all args having the constraints removed from call, or
113113
defaulting to `Number`
114114
- `partials`: which are all `Expr{:tuple,...}`
@@ -118,9 +118,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
118118
# Setup: normalizing input form etc
119119

120120
if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup")
121-
setup_stmts = map(esc, maybe_setup.args[3:end])
121+
setup_stmts = Any[esc(ex) for ex in maybe_setup.args[3:end]]
122122
else
123-
setup_stmts = (nothing,)
123+
setup_stmts = []
124124
partials = (maybe_setup, partials...)
125125
end
126126
@assert Meta.isexpr(call, :call)
@@ -185,10 +185,14 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
185185
# because this is a pull-back there is one per output of function
186186
Δs = _propagator_inputs(n_outputs)
187187

188+
# Make a projector for each argument
189+
ps, setup_ps = _make_projectors(call.args[2:end])
190+
append!(setup_stmts, setup_ps)
191+
188192
# 1 partial derivative per input
189193
pullback_returns = map(1:n_inputs) do input_i
190194
∂s = [partial.args[input_i] for partial in partials]
191-
propagation_expr(Δs, ∂s, true)
195+
propagation_expr(Δs, ∂s, true, ps[input_i])
192196
end
193197

194198
# Multi-output functions have pullbacks with a tuple input that will be destructured
@@ -215,14 +219,22 @@ end
215219
"Declares properly hygenic inputs for propagation expressions"
216220
_propagator_inputs(n) = [esc(gensym(Symbol(, i))) for i in 1:n]
217221

222+
"given the variable names, escaped but without types, makes setup expressions for projection operators"
223+
function _make_projectors(xs)
224+
ps = map(x -> gensym(Symbol(:proj_, x.args[1])), xs)
225+
setups = map((x,p) -> :($p = ProjectTo($x)), xs, ps)
226+
return ps, setups
227+
end
228+
218229
"""
219-
propagation_expr(Δs, ∂s, _conj = false)
230+
propagation_expr(Δs, ∂s, [_conj = false, proj = identity])
220231
221-
Returns the expression for the propagation of
222-
the input gradient `Δs` though the partials `∂s`.
223-
Specify `_conj = true` to conjugate the partials.
232+
Returns the expression for the propagation of
233+
the input gradient `Δs` though the partials `∂s`.
234+
Specify `_conj = true` to conjugate the partials.
235+
Projector `proj`, usually `ProjectTo(x)`, is applied at the end.
224236
"""
225-
function propagation_expr(Δs, ∂s, _conj = false)
237+
function propagation_expr(Δs, ∂s, _conj = false, proj = identity)
226238
# This is basically Δs ⋅ ∂s
227239
_∂s = map(∂s) do ∂s_i
228240
if _conj
@@ -249,7 +261,7 @@ function propagation_expr(Δs, ∂s, _conj = false)
249261
:($(_∂s[1]) * $(Δs[1]))
250262
end
251263

252-
return summed_∂_mul_Δs
264+
return :($proj($summed_∂_mul_Δs))
253265
end
254266

255267
"""

0 commit comments

Comments
 (0)