Skip to content

Commit 47d0202

Browse files
authored
Make Rules store a second function for updating (#42)
Now `Rule`s store two things: one callable thing, which is used for evaluating the rule, and one other thing, which can be `nothing` or a function with the signature `u(value, args...)` that evaluates the rule for the given arguments and adds the result to `value`, doing so in place when possible. The advantage of this is that we can more easily define custom ways of accumulating the results of rules, and we can even share intermediate steps between the regular rule evaluation and custom updating. As a test/proof of concept, the `rrule` for `svd` now also has an updating function that handles `NamedTuple`s appropriately.
1 parent a2e6451 commit 47d0202

File tree

4 files changed

+86
-7
lines changed

4 files changed

+86
-7
lines changed

src/rules.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,32 @@ See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
108108
"""
109109
store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...)))
110110

111+
# Special purpose updating for operations which can be done in-place. This function is
112+
# just internal and free-form; it is not a method of `accumulate!` directly as it does
113+
# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`.
114+
# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular
115+
# rule.
116+
117+
_update!(x, y) = x + y
118+
_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y
119+
120+
_update!(x, ::Zero) = x
121+
_update!(::Zero, y) = y
122+
_update!(::Zero, ::Zero) = Zero()
123+
124+
function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns
125+
return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns))
126+
end
127+
128+
function _update!(x::NamedTuple, y, p::Symbol)
129+
new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),))
130+
return merge(x, new)
131+
end
132+
133+
function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns
134+
return _update!(x, getproperty(y, p), p)
135+
end
136+
111137
#####
112138
##### `Rule`
113139
#####
@@ -123,14 +149,18 @@ Cassette.overdub(::RuleContext, ::typeof(add), a, b) = add(a, b)
123149
Cassette.overdub(::RuleContext, ::typeof(mul), a, b) = mul(a, b)
124150

125151
"""
126-
Rule(propation_function)
152+
Rule(propation_function[, updating_function])
127153
128154
Return a `Rule` that wraps the given `propation_function`. It is assumed that
129155
`propation_function` is a callable object whose arguments are differential
130156
values, and whose output is a single differential value calculated by applying
131157
internally stored/computed partial derivatives to the input differential
132158
values.
133159
160+
If an updating function is provided, it is assumed to have the signature `u(Δ, xs...)`
161+
and to store the result of the propagation function applied to the arguments `xs` into
162+
`Δ` in-place, returning `Δ`.
163+
134164
For example:
135165
136166
```
@@ -141,12 +171,21 @@ rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * Δ
141171
142172
See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref)
143173
"""
144-
struct Rule{F} <: AbstractRule
174+
struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule
145175
f::F
176+
u::U
146177
end
147178

179+
# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some
180+
# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}`
181+
Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)
182+
148183
(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...)
149184

185+
# Specialized accumulation
186+
# TODO: Does this need to be overdubbed in the rule context?
187+
accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)
188+
150189
#####
151190
##### `DNERule`
152191
#####

src/rules/linalg/factorization.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@ end
1212

1313
function rrule(::typeof(getproperty), F::SVD, x::Symbol)
1414
if x === :U
15-
return F.U, (Rule(->(U=Ȳ, S=zero(F.S), V=zero(F.V))), DNERule())
15+
rule = ->(U=Ȳ, S=zero(F.S), V=zero(F.V))
1616
elseif x === :S
17-
return F.S, (Rule(->(U=zero(F.U), S=Ȳ, V=zero(F.V))), DNERule())
17+
rule = ->(U=zero(F.U), S=Ȳ, V=zero(F.V))
1818
elseif x === :V
19-
return F.V, (Rule(->(U=zero(F.U), S=zero(F.S), V=)), DNERule())
19+
rule = ->(U=zero(F.U), S=zero(F.S), V=Ȳ)
2020
elseif x === :Vt
21-
return F.Vt, (Rule(Ȳ->(U=zero(F.U), S=zero(F.S), V=')), DNERule())
21+
# TODO: This could be made to work, but it'd be a pain
22+
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
2223
end
24+
update = (X̄::NamedTuple{(:U,:S,:V)}, Ȳ)->_update!(X̄, rule(Ȳ), x)
25+
return getproperty(F, x), (Rule(rule, update), DNERule())
2326
end
2427

2528
function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix)

test/rules.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,27 @@ cool(x) = x + 1
2121
end
2222
@test i == 1 # rules only iterate once, yielding themselves
2323
end
24+
@testset "helper functions" begin
25+
# Hits fallback, since we can't update `Diagonal`s in place
26+
X = Diagonal([1, 1])
27+
Y = copy(X)
28+
@test ChainRules._update!(X, [1 2; 3 4]) == [2 2; 3 5]
29+
@test X == Y # no change to X
30+
31+
X = [1 2; 3 4]
32+
Y = copy(X)
33+
@test ChainRules._update!(X, Diagonal([1, 1])) == [2 2; 3 5]
34+
@test X != Y # X has been updated
35+
36+
# Reusing above X
37+
@test ChainRules._update!(X, Zero()) === X
38+
@test ChainRules._update!(Zero(), X) === X
39+
@test ChainRules._update!(Zero(), Zero()) === Zero()
40+
41+
X = (A=[1 0; 0 1], B=[2 2; 2 2])
42+
Y = deepcopy(X)
43+
@test ChainRules._update!(X, Y) == (A=[2 0; 0 2], B=[4 4; 4 4])
44+
@test X.A != Y.A
45+
@test X.B != Y.B
46+
end
2447
end

test/rules/linalg/factorization.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,28 @@
44
for n in [4, 6, 10], m in [3, 5, 10]
55
X = randn(rng, n, m)
66
F, dX = rrule(svd, X)
7-
for p in [:U, :S, :V, :Vt]
7+
for p in [:U, :S, :V]
88
Y, (dF, dp) = rrule(getproperty, F, p)
99
@test dp isa ChainRules.DNERule
1010
= randn(rng, size(Y)...)
1111
X̄_ad = dX(dF(Ȳ))
1212
X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)
1313
@test X̄_ad X̄_fd rtol=1e-6 atol=1e-6
1414
end
15+
@test_throws ArgumentError rrule(getproperty, F, :Vt)
16+
end
17+
@testset "accumulate!" begin
18+
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
19+
F, dX = rrule(svd, X)
20+
= (U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
21+
for p in [:U, :S, :V]
22+
Y, (dF, _) = rrule(getproperty, F, p)
23+
= ones(size(Y)...)
24+
ChainRules.accumulate!(X̄, dF, Ȳ)
25+
end
26+
@test.U ones(3, 2) atol=1e-6
27+
@test.S ones(2) atol=1e-6
28+
@test.V ones(2, 2) atol=1e-6
1529
end
1630
@testset "Helper functions" begin
1731
X = randn(rng, 10, 10)

0 commit comments

Comments
 (0)