Skip to content

Commit bc830c6

Browse files
authored
Merge pull request #395 from mcabbott/scalarmacro
Always include `ProjectTo` in `@scalar_rule`
2 parents 6debcc3 + 0d6593e commit bc830c6

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

src/rule_definition_tools.jl

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ end
108108
109109
returns (in order) the correctly escaped:
110110
- `call` without any type constraints
111-
- `setup_stmts`: the content of `@setup` or `nothing` if that is not provided,
111+
- `setup_stmts`: the content of `@setup` or `[]` if that is not provided,
112112
- `inputs`: with all args having the constraints removed from call, or
113113
defaulting to `Number`
114114
- `partials`: which are all `Expr{:tuple,...}`
@@ -118,9 +118,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
118118
# Setup: normalizing input form etc
119119

120120
if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup")
121-
setup_stmts = map(esc, maybe_setup.args[3:end])
121+
setup_stmts = Any[esc(ex) for ex in maybe_setup.args[3:end]]
122122
else
123-
setup_stmts = (nothing,)
123+
setup_stmts = []
124124
partials = (maybe_setup, partials...)
125125
end
126126
@assert Meta.isexpr(call, :call)
@@ -185,10 +185,14 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
185185
# because this is a pull-back there is one per output of function
186186
Δs = _propagator_inputs(n_outputs)
187187

188+
# Make a projector for each argument
189+
projs, psetup = _make_projectors(call.args[2:end])
190+
append!(setup_stmts, psetup)
191+
188192
# 1 partial derivative per input
189193
pullback_returns = map(1:n_inputs) do input_i
190194
∂s = [partial.args[input_i] for partial in partials]
191-
propagation_expr(Δs, ∂s, true)
195+
propagation_expr(Δs, ∂s, true, projs[input_i])
192196
end
193197

194198
# Multi-output functions have pullbacks with a tuple input that will be destructured
@@ -215,14 +219,23 @@ end
215219
"Declares properly hygienic inputs for propagation expressions"
216220
_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n]
217221

222+
"given the variable names, escaped but without types, makes setup expressions for projection operators"
223+
function _make_projectors(xs)
224+
projs = map(x -> Symbol(:proj_, x.args[1]), xs)
225+
setups = map((x,p) -> :($p = ProjectTo($x)), xs, projs)
226+
return projs, setups
227+
end
228+
218229
"""
219-
propagation_expr(Δs, ∂s, _conj = false)
230+
propagation_expr(Δs, ∂s, [_conj=false, proj=identity])
220231
221-
Returns the expression for the propagation of
222-
the input gradient `Δs` though the partials `∂s`.
223-
Specify `_conj = true` to conjugate the partials.
232+
Returns the expression for the propagation of
233+
the input gradient `Δs` through the partials `∂s`.
234+
Specify `_conj = true` to conjugate the partials.
235+
Projector `proj` is a function that will be applied at the end;
236+
for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity`
224237
"""
225-
function propagation_expr(Δs, ∂s, _conj = false)
238+
function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
226239
# This is basically Δs ⋅ ∂s
227240
_∂s = map(∂s) do ∂s_i
228241
if _conj
@@ -249,7 +262,7 @@ function propagation_expr(Δs, ∂s, _conj = false)
249262
:($(_∂s[1]) * $(Δs[1]))
250263
end
251264

252-
return summed_∂_mul_Δs
265+
return :($proj($summed_∂_mul_Δs))
253266
end
254267

255268
"""

test/rule_definition_tools.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,18 @@ end
237237
@testisa Tangent{Tuple{Irrational{}, Float64}, Tuple{Float32, Float32}}
238238
end
239239

240+
@testset "@scalar_rule projection" begin
241+
make_imaginary(x) = im*x
242+
@scalar_rule make_imaginary(x) im
243+
244+
# note: the === will make sure that these are Float64, not ComplexF64
245+
@test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0*im)
246+
@test (NoTangent(), 0.0) === rrule(make_imaginary, 2.0)[2](1.0)
247+
248+
@test (NoTangent(), 1.0+0.0im) === rrule(make_imaginary, 2.0im)[2](1.0*im)
249+
@test (NoTangent(), 0.0-1.0im) === rrule(make_imaginary, 2.0im)[2](1.0)
250+
end
251+
240252
@testset "Regression tests against #276 and #265" begin
241253
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276
242254
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265

0 commit comments

Comments
 (0)