Commit 9007ad5

Add implementation of Lion optimiser (#129)

1 parent e2254b4

3 files changed: +32 -2 lines

src/Optimisers.jl

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ export destructure
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
-       WeightDecay, ClipGrad, ClipNorm, OptimiserChain
+       WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion

 ###
 ### one-array functions
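With the export in place, the new rule is reachable through the package's standard setup/update! API. A minimal usage sketch (the NamedTuple model and the all-ones gradients are illustrative stand-ins, not part of this commit):

using Optimisers

# Hypothetical two-parameter model; any nested structure of arrays works.
model = (W = rand(3, 3), b = zeros(3))

# Defaults written out: η = 1e-3, β = (0.9, 0.999).
state = Optimisers.setup(Lion(1e-3, (0.9, 0.999)), model)

grads = (W = ones(3, 3), b = ones(3))   # stand-in for real gradients
state, model = Optimisers.update!(state, model, grads)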

src/rules.jl

Lines changed: 30 additions & 0 deletions
@@ -217,6 +217,36 @@ function apply!(o::Adam, state, x, dx)
   return (mt, vt, βt .* β), dx′
 end

+"""
+    Lion(η = 0.001, β::Tuple = (0.9, 0.999))
+
+[Lion](https://arxiv.org/abs/2302.06675) optimiser.
+
+# Parameters
+- Learning rate (`η`): Magnitude by which gradients are updating the weights.
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
+"""
+struct Lion{T} <: AbstractRule
+  eta::T
+  beta::Tuple{T,T}
+end
+Lion(η = 1f-3, β = (9f-1, 9.99f-1)) = Lion{typeof(η)}(η, β)
+
+init(o::Lion, x::AbstractArray) = zero(x)
+
+function apply!(o::Lion, state, x, dx)
+  η, β = o.eta, o.beta
+
+  @.. state = β[2] * state + (1-β[2]) * dx
+
+  # The paper writes the update in terms of the old momentum,
+  # but easy to solve in terms of the current momentum instead:
+  dx′ = @lazy η * sign((β[2]-β[1]) * dx + β[1] * state)
+
+  return state, dx′
+end
+
 """
     RAdam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
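The comment in apply! compresses a short derivation. The paper forms c = β1*m_old + (1-β1)*g, takes the update η*sign(c), and only then sets m_new = β2*m_old + (1-β2)*g. Substituting m_old = (m_new - (1-β2)*g)/β2 gives c = (β1*m_new + (β2-β1)*g)/β2, and since sign ignores the positive factor 1/β2, the code can work from the already-updated state. A standalone check of that identity, with arbitrary illustrative values:

# β2 * (β1*m_old + (1-β1)*g) == β1*m_new + (β2-β1)*g, so the signs agree.
β1, β2 = 0.9, 0.999
m_old, g = randn(4), randn(4)
m_new = β2 .* m_old .+ (1 - β2) .* g            # the EMA step from apply!
paper     = sign.(β1 .* m_old .+ (1 - β1) .* g)  # update as the paper writes it
rewritten = sign.((β2 - β1) .* g .+ β1 .* m_new) # update as apply! computes it
@assert paper == rewritten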

test/rules.jl

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ RULES = [
   # All the rules at default settings:
   Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
-  AdamW(), RAdam(), OAdam(), AdaBelief(),
+  AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
   # A few chained combinations:
   OptimiserChain(WeightDecay(), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
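As an AbstractRule, Lion also composes with the chained combinations tested above. An illustrative sketch (this particular pairing with decoupled weight decay is an assumption, not part of this commit):

using Optimisers
rule  = OptimiserChain(WeightDecay(1e-4), Lion())
x     = rand(5)
state = Optimisers.setup(rule, x)
state, x = Optimisers.update!(state, x, ones(5))  # stand-in gradient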
