Commit e6d8160

Rule for gradient accumulation (#137)
* implement AccumGrad
* gradient accumulation
* new interface
* more tests
* remove NoUpdate
* fix
* test for subtract! Zero
* fix
* don't test AccumGrad with other rules
* another variant
* Update src/rules.jl

Co-authored-by: Michael Abbott <[email protected]>

* less docs

---------

Co-authored-by: Michael Abbott <[email protected]>
1 parent 14949f1 commit e6d8160

File tree: 7 files changed (+100, -5 lines)
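For orientation, here is a minimal sketch of how the new rule is meant to be used, based on the docstring added to src/rules.jl below. The model, the loop over random gradients, and the `Adam(1e-3)` settings are illustrative placeholders, not part of the commit:

```julia
using Optimisers

# Accumulate 4 micro-batch gradients, then take one Adam step on their mean.
model = (w = rand(Float32, 10),)                  # any Functors-compatible model
rule  = OptimiserChain(AccumGrad(4), Adam(1e-3))
state = Optimisers.setup(rule, model)

for step in 1:8
    grad = (w = rand(Float32, 10),)               # stand-in for a real gradient
    # The parameters move only on steps 4 and 8; the other calls just add
    # grad/4 into AccumGrad's state and leave `model` unchanged.
    state, model = Optimisers.update(state, model, grad)
end
```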

src/Optimisers.jl

Lines changed: 2 additions & 1 deletion

@@ -14,7 +14,8 @@ export destructure
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
-       WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion
+       WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
+       AccumGrad

 ###
 ### one-array functions

src/adjust.jl

Lines changed: 0 additions & 1 deletion

@@ -144,4 +144,3 @@ function _adjust(r::T, nt::NamedTuple) where T <: AbstractRule
   end
   T(vals...)  # relies on having the default constructor
 end
-

src/backup.jl

Whitespace-only changes.

src/interface.jl

Lines changed: 4 additions & 1 deletion

@@ -1,7 +1,8 @@
-
 using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent
+
 base(dx::Tangent) = backing(canonicalize(dx))
 base(dx) = dx
+
 const Zero = Union{Nothing, AbstractZero}  # Union{Zygote, Diffractor}

 abstract type AbstractRule end

@@ -96,6 +97,7 @@ function _update!(ℓ::Leaf, x; grads, params)
 end

 subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
+subtract!(x, x̄::Zero) = x

 _grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing
 function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)

@@ -222,3 +224,4 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)

 onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
 onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x)
+
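The new `subtract!(x, x̄::Zero) = x` method is what lets a rule signal "no update this step" by returning `nothing` (or a `ZeroTangent`) as the transformed gradient: the leaf then leaves the parameters untouched instead of trying to broadcast `x .- nothing`. A small sketch of the dispatch, with illustrative values:

```julia
using Optimisers, ChainRulesCore

x = [1.0, 2.0, 3.0]

# Ordinary gradient: in-place broadcasted subtraction, as before.
Optimisers.subtract!(x, [0.1, 0.1, 0.1])       # x is now [0.9, 1.9, 2.9]

# Zero-like gradients (`nothing` or `ZeroTangent()`) fall into the `Zero`
# union, and the new method returns `x` unchanged.
Optimisers.subtract!(x, nothing) === x          # true
Optimisers.subtract!(x, ZeroTangent()) === x    # true
```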

src/rules.jl

Lines changed: 64 additions & 2 deletions

@@ -631,6 +631,7 @@ so `update!` will subtract the full gradient from the parameters.
 This is equivalent to `Descent(1)`.

 # Example
+
 ```jldoctest
 julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));


@@ -654,8 +655,12 @@ init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts)

 function apply!(o::OptimiserChain, states, x, dx, dxs...)
   foldl(tuple.(o.opts, states); init = ((), dx)) do (states′, dx′), (opt, state)
-    state′, dx′ = apply!(opt, state, x, dx′, dxs...)
-    return (states′..., state′), dx′
+    if dx′ isa Zero
+      return (states′..., state), dx′
+    else
+      state′, dx′ = apply!(opt, state, x, dx′, dxs...)
+      return (states′..., state′), dx′
+    end
   end
 end

@@ -667,3 +672,60 @@ end

 adjust(ℓ::OptimiserChain, eta::Real) = OptimiserChain(map(opt -> adjust(opt, eta), ℓ.opts)...)
 adjust(ℓ::OptimiserChain; kw...) = OptimiserChain(map(opt -> adjust(opt; kw...), ℓ.opts)...)
+
+
+"""
+    AccumGrad(n::Int)
+
+A rule constructed `OptimiserChain(AccumGrad(n), Rule())` will accumulate for `n` steps,
+before applying `Rule` to the mean of these `n` gradients.
+
+This is useful for training with effective batch sizes too large for the available memory.
+Instead of computing the gradient for batch size `b` at once, compute it for size `b/n` and
+accumulate `n` such gradients.
+
+# Example
+```jldoctest
+julia> m = (x=[1f0], y=[2f0]);
+
+julia> r = OptimiserChain(AccumGrad(2), WeightDecay(0.01), Descent(0.1));
+
+julia> s = Optimisers.setup(r, m);
+
+julia> Optimisers.update!(s, m, (x=[33], y=[0]));
+
+julia> m  # model not yet changed
+(x = Float32[1.0], y = Float32[2.0])
+
+julia> Optimisers.update!(s, m, (x=[0], y=[444]));
+
+julia> m  # n=2 gradients applied at once
+(x = Float32[-0.651], y = Float32[-20.202])
+```
+"""
+struct AccumGrad <: AbstractRule
+  n::Int
+
+  function AccumGrad(n::Int)
+    n > 0 || throw(ArgumentError("AccumGrad must accumulate at least one gradient"))
+    return new(n)
+  end
+end
+
+function init(o::AccumGrad, x)
+  return (zero(x), 1)
+end
+
+function apply!(o::AccumGrad, state, x, dx)
+  accum_dx, counter = state
+  if counter == 1
+    @.. accum_dx = dx / o.n
+  else
+    @.. accum_dx = accum_dx + dx / o.n
+  end
+  if counter == o.n
+    return (accum_dx, 1), accum_dx
+  else
+    return (accum_dx, counter + 1), nothing
+  end
+end
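To make the doctest numbers in the new docstring concrete: with `AccumGrad(2)` the mean gradient for `x` over the two `update!` calls is `(33 + 0)/2 = 16.5`, `WeightDecay(0.01)` then adds `0.01 * x`, and `Descent(0.1)` scales by the learning rate, so `x` moves from `1.0` to `1.0 - 0.1 * (16.5 + 0.01 * 1.0) = -0.651`. A hand-check of the same arithmetic (illustration only, not part of the commit):

```julia
# Reproduce the doctest result (x = -0.651, y = -20.202) by hand.
n, γ, η = 2, 0.01, 0.1      # AccumGrad(2), WeightDecay(0.01), Descent(0.1)
x, y = 1.0, 2.0

gx = (33 + 0) / n           # mean gradient for x over the two update! calls
gy = (0 + 444) / n          # mean gradient for y

Δx = η * (gx + γ * x)       # WeightDecay adds γ*x to the gradient, Descent scales by η
Δy = η * (gy + γ * y)

(x - Δx, y - Δy)            # ≈ (-0.651, -20.202), matching the docstring
```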

test/rules.jl

Lines changed: 23 additions & 0 deletions

@@ -244,3 +244,26 @@ VERSION < v"1.9-" && @testset "using Yota" begin
   @test loss(w, w′) < 0.001
 end
 end
+
+@testset "AccumGrad" begin
+  x0 = rand(5)
+  x = copy(x0)
+  lr = 0.01
+  tree = Optimisers.setup(OptimiserChain(AccumGrad(3), Descent(lr)), x)
+
+  g1 = rand(5)
+  tree, x1 = Optimisers.update(tree, x, g1)
+  @test x1 ≈ x
+  @test x1 ≈ x0
+  g2 = rand(5)
+  tree, x2 = Optimisers.update(tree, x1, g2)
+  @test x2 ≈ x
+  @test x2 ≈ x0
+  g3 = rand(5)
+  tree, x3 = Optimisers.update(tree, x2, g3)
+  @test x3 ≈ x0 .- lr .* (g1 .+ g2 .+ g3) ./ 3
+  g4 = rand(5)
+
+  tree, x4 = Optimisers.update(tree, x3, g4)
+  @test x4 ≈ x3
+end

test/runtests.jl

Lines changed: 7 additions & 0 deletions

@@ -499,6 +499,13 @@ y2z(x) = x
   end
 end  # 2nd-order

+@testset "subtract! handles Zero" begin
+  x = rand(3)
+  y = Optimisers.subtract!(x, ChainRulesCore.ZeroTangent())
+  @test y === x
+  y = Optimisers.subtract!(x, nothing)
+  @test y === x
+end
 end
 @testset verbose=true "Destructure" begin
   include("destructure.jl")
