
Commit 44fc0ce

Add gradient clipping (#27)
* add gradient clipping
* include in OptimiserChain tests
* from review
* docs
* rebase fixup
* wording
1 parent 755a97a commit 44fc0ce

File tree: 5 files changed, +112 -24 lines

Project.toml
docs/src/api.md
src/Optimisers.jl
src/rules.jl
test/runtests.jl

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ version = "0.1.0"
 
 [deps]
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

docs/src/api.md

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,11 @@ Optimisers.AMSGrad
 Optimisers.NADAM
 Optimisers.ADAMW
 Optimisers.AdaBelief
+```
+
+```
+Optimisers.ClipGrad
+Optimisers.ClipNorm
 Optimisers.WeightDecay
 Optimisers.OptimiserChain
 ```

src/Optimisers.jl

Lines changed: 2 additions & 1 deletion
@@ -1,12 +1,13 @@
 module Optimisers
 
 using Functors: functor, fmap, isleaf
+using LinearAlgebra
 
 include("interface.jl")
 include("rules.jl")
 
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
        ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
-       WeightDecay, OptimiserChain
+       WeightDecay, ClipGrad, ClipNorm, OptimiserChain
 
 end # module

src/rules.jl

Lines changed: 55 additions & 0 deletions
@@ -490,6 +490,61 @@ function apply(o::WeightDecay, state, x, dx)
   return state, dx′
 end
 
+"""
+    ClipGrad(δ = 10f0)
+
+Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.
+
+See also [`ClipNorm`](@ref).
+"""
+struct ClipGrad{T<:Real}
+  delta::T
+end
+ClipGrad() = ClipGrad(10f0)
+
+init(o::ClipGrad, x::AbstractArray) = nothing
+
+(o::ClipGrad)(state::Nothing, m, dm) = update(o, state, m, dm)
+
+function apply(o::ClipGrad, state, x, dx)
+  δ = convert(eltype(dx), o.delta)
+  dx′ = @. clamp(dx, -δ, δ)
+
+  return state, dx′
+end
+
+"""
+    ClipNorm(ω = 10f0, p = 2; throw = true)
+
+Scales any gradient array for which `norm(dx, p) > ω`
+to stay at this threshold (unless `p==0`).
+
+Throws an error if the norm is infinite or `NaN`,
+which you can turn off with `throw = false`.
+
+See also [`ClipGrad`](@ref).
+"""
+struct ClipNorm{T<:Real}
+  omega::T
+  p::T
+  throw::Bool
+end
+ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{typeof(ω)}(ω, p, throw)
+
+init(o::ClipNorm, x::AbstractArray) = nothing
+
+(o::ClipNorm)(state::Nothing, m, dm) = update(o, state, m, dm)
+
+function apply(o::ClipNorm, state, x, dx)
+  nrm = norm(dx, o.p)
+  if o.throw && !isfinite(nrm)
+    throw(DomainError("gradient has $(o.p)-norm $nrm, for array $(summary(x))"))
+  end
+  λ = min(o.omega / nrm, 1)
+
+  return state, @. dx * λ
+end
+
 """
     OptimiserChain(opts...)

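Not part of the diff: a minimal sketch of how the two new rules differ, written against the `Optimisers.state` / `Optimisers.update` calling pattern that the updated tests below use. The toy model `m`, gradient `dm`, and thresholds are made up for illustration.

```julia
using Optimisers, LinearAlgebra

m  = (γ = zeros(3),)            # toy "model" holding one array
dm = (γ = [1.0, 10.0, 100.0],)  # matching gradient with one oversized entry

# ClipGrad clamps every gradient component into [-δ, δ] independently.
c1 = ClipGrad(13)
s1 = Optimisers.state(c1, m)
_, m1 = Optimisers.update(c1, s1, m, dm)
m.γ .- m1.γ        # the step taken is the clipped gradient: [1.0, 10.0, 13.0]

# ClipNorm rescales the whole array so that its p-norm is at most ω.
c2 = ClipNorm(10)
s2 = Optimisers.state(c2, m)
_, m2 = Optimisers.update(c2, s2, m, dm)
norm(m.γ .- m2.γ)  # ≈ 10: direction kept, only the length is capped
```

The practical difference: `ClipGrad` can change the direction of the gradient, while `ClipNorm` only shrinks its length.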
test/runtests.jl

Lines changed: 49 additions & 23 deletions
@@ -1,13 +1,14 @@
 using Optimisers, Test
-using Zygote, Random
-using Statistics
+using Zygote
+using Statistics, Random, LinearAlgebra
+Random.seed!(84)
+
+@testset verbose=true "Optimisers.jl" begin
 
-@testset "Optimisers" begin
-  Random.seed!(84)
-  w′ = (α = rand(3, 3), β = rand(3, 3))
 @testset for o in (Descent(), ADAM(), Momentum(), Nesterov(), RMSProp(),
                    ADAGrad(), AdaMax(), ADADelta(), AMSGrad(), NADAM(),
                    ADAMW(), RADAM(), OADAM(), AdaBelief())
+  w′ = (α = rand(3, 3), β = rand(3, 3))
 
   # Original example
   w = (α = 5rand(3, 3), β = rand(3, 3))
@@ -37,25 +38,50 @@ using Statistics
   end
 
 end
-end
 
-@testset "OptimiserChain" begin
-  Random.seed!(84)
-  w = randn(10, 10)
-  w′ = randn(10, 10)
-  loss(x, w, w′) = mean((w*x .- w′*x) .^ 2)
-  opt = OptimiserChain(WeightDecay(), ADAM(0.001))
-  st = Optimisers.state(opt, w)
-  for t = 1:10^5
-    x = rand(10)
-    gs = gradient(w -> loss(x, w, w′), w)
-    st, w = Optimisers.update(opt, st, w, gs...)
+@testset "OptimiserChain with $pre" for pre in (WeightDecay(), ClipGrad(), ClipNorm())
+  Random.seed!(84)
+  w = randn(10, 10)
+  w′ = randn(10, 10)
+  loss(x, w, w′) = mean((w*x .- w′*x) .^ 2)
+  @test loss(rand(10, 10), w, w′) > 1
+  opt = OptimiserChain(pre, ADAM(0.001))
+  st = Optimisers.init(opt, w)
+  for t = 1:10^5
+    x = rand(10)
+    gs = gradient(w -> loss(x, w, w′), w)
+    st, w = Optimisers.update(opt, st, w, gs...)
+  end
+  @test loss(rand(10, 10), w, w′) < 0.01
+end
+
+@testset "gradient clipping" begin
+  @test_skip m = (α = ([0], sin), γ = rand(3)) # https://github.com/FluxML/Optimisers.jl/issues/28
+  m = (α = ([0], [0]), γ = rand(3))
+  c1 = ClipGrad(13)
+  s1 = Optimisers.state(c1, m)
+  _, g1 = Optimisers.update(c1, s1, m, (α = nothing, γ = [1,10,100],))
+  @test m.γ .- g1.γ ≈ [1, 10, 13]
+
+  c2 = ClipNorm(10)
+  s2 = Optimisers.state(c2, m)
+  _, g2 = Optimisers.update(c2, s2, m, (α = ([0.1], nothing), γ = [1,10,100],))
+  @test only(m.α[1] .- g2.α[1]) ≈ 0.1
+  @test norm(m.γ .- g2.γ) ≈ 10
+  @test_throws DomainError Optimisers.update(c2, s2, m, (α = [0.1], γ = [1,10,NaN],))
+
+  c3 = ClipNorm(5, 1; throw=false)
+  _, g3 = Optimisers.update(c3, s2, m, (α = ([0.1], nothing), γ = [1,10,100],))
+  @test only(m.α[1] .- g3.α[1]) ≈ 0.1
+  @test norm(m.γ .- g3.γ, 1) ≈ 5
+  _, g3n = Optimisers.update(c3, s2, m, (α = nothing, γ = [1,10,Inf],))
+  @test isnan(g3n.γ[3])
+end
+
+@testset "Optimiser Updates" begin
+  opt = ADAM()
+  new_opt = ADAM(opt, eta = 9.f0)
+  @test new_opt.eta == 9.f0
 end
-  @test loss(rand(10, 10), w, w′) < 0.01
-end
 
-@testset "Optimiser Updates" begin
-  opt = ADAM()
-  new_opt = ADAM(opt, eta = 9.f0)
-  @test new_opt.eta == 9.f0
 end

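For context, the chained use exercised by the `OptimiserChain with $pre` testset looks roughly like this in user code. This is a sketch rather than part of the commit: the clipping threshold, the toy least-squares problem, and the loop length mirror the test and are only illustrative.

```julia
using Optimisers, Zygote, Statistics

let
  w, w′ = randn(10, 10), randn(10, 10)
  loss(x, w, w′) = mean((w*x .- w′*x) .^ 2)

  # Clip each gradient entry to ±1, then let ADAM take the step.
  opt = OptimiserChain(ClipGrad(1.0), ADAM(0.001))
  st  = Optimisers.init(opt, w)

  for t in 1:10^5
    x  = rand(10)
    gs = gradient(w -> loss(x, w, w′), w)
    st, w = Optimisers.update(opt, st, w, gs...)
  end

  loss(rand(10, 10), w, w′)  # should end up far below its starting value
end
```

Swapping `ClipGrad(1.0)` for `ClipNorm()` or `WeightDecay()` matches the other two cases the testset covers.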