Skip to content

Commit 4a78a55

Browse files
fix epsilon for Float16 (#190)
1 parent 2da6d7f commit 4a78a55

File tree

5 files changed

+29
-11
lines changed

5 files changed

+29
-11
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ Manifest.toml
22
.vscode/
33
docs/build/
44
.DS_Store
5+
/test.jl

src/Optimisers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
2525
WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
2626
AccumGrad
2727

28+
VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))
29+
2830
###
2931
### one-array functions
3032
###

src/rules.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, eps
130130
init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)
131131

132132
function apply!(o::RMSProp, state, x::AbstractArray{T}, dx) where T
133-
η, ρ, ϵ = T(o.eta), T(o.rho), T(o.epsilon)
133+
η, ρ, ϵ = T(o.eta), T(o.rho), _eps(T, o.epsilon)
134134
quad, lin = state
135135

136136
@.. quad = ρ * quad + (1 - ρ) * abs2(dx)
@@ -216,7 +216,7 @@ end
216216
init(o::Adam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
217217

218218
function apply!(o::Adam, state, x::AbstractArray{T}, dx) where T
219-
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
219+
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
220220
mt, vt, βt = state
221221

222222
@.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -279,7 +279,7 @@ end
279279
init(o::RAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta), 1)
280280

281281
function apply!(o::RAdam, state, x::AbstractArray{T}, dx) where T
282-
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
282+
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
283283
ρ∞ = 2/(1-β[2]) - 1 |> real
284284

285285
mt, vt, βt, t = state
@@ -320,7 +320,7 @@ end
320320
init(o::AdaMax, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
321321

322322
function apply!(o::AdaMax, state, x::AbstractArray{T}, dx) where T
323-
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
323+
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
324324
mt, ut, βt = state
325325

326326
@.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -354,7 +354,7 @@ end
354354
init(o::OAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta), zero(x))
355355

356356
function apply!(o::OAdam, state, x::AbstractArray{T}, dx) where T
357-
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
357+
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
358358
mt, vt, βt, term = state
359359

360360
@.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -388,7 +388,7 @@ end
388388
init(o::AdaGrad, x::AbstractArray) = onevalue(o.epsilon, x)
389389

390390
function apply!(o::AdaGrad, state, x::AbstractArray{T}, dx) where T
391-
η, ϵ = T(o.eta), T(o.epsilon)
391+
η, ϵ = T(o.eta), _eps(T, o.epsilon)
392392
acc = state
393393

394394
@.. acc = acc + abs2(dx)
@@ -418,7 +418,7 @@ end
418418
init(o::AdaDelta, x::AbstractArray) = (zero(x), zero(x))
419419

420420
function apply!(o::AdaDelta, state, x::AbstractArray{T}, dx) where T
421-
ρ, ϵ = T(o.rho), T(o.epsilon)
421+
ρ, ϵ = T(o.rho), _eps(T, o.epsilon)
422422
acc, Δacc = state
423423

424424
@.. acc = ρ * acc + (1 - ρ) * abs2(dx)
@@ -454,7 +454,7 @@ init(o::AMSGrad, x::AbstractArray) =
454454
(onevalue(o.epsilon, x), onevalue(o.epsilon, x), onevalue(o.epsilon, x))
455455

456456
function apply!(o::AMSGrad, state, x::AbstractArray{T}, dx) where T
457-
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
457+
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
458458
mt, vt, v̂t = state
459459

460460
@.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -489,8 +489,7 @@ end
489489
init(o::NAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
490490

491491
function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
492-
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
493-
492+
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
494493
mt, vt, βt = state
495494

496495
@.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -548,7 +547,7 @@ end
548547
init(o::AdaBelief, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
549548

550549
function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
551-
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
550+
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
552551
mt, st, βt = state
553552

554553
@.. mt = β[1] * mt + (1 - β[1]) * dx

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,9 @@ foreachvalue(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
1414
end
1515

1616
ofeltype(x, y) = convert(float(eltype(x)), y)
17+
18+
_eps(T::Type{<:AbstractFloat}, e) = T(e)
19+
# catch complex and integers
20+
_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e)
21+
# avoid small e being rounded to zero
22+
_eps(T::Type{Float16}, e) = e == 0 ? T(0) : max(T(1e-7), T(e))

test/rules.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,13 @@ end
267267
tree, x4 = Optimisers.update(tree, x3, g4)
268268
@test x4 x3
269269
end
270+
271+
@testset "Float16 epsilon" begin
272+
# issue https://github.com/FluxML/Optimisers.jl/issues/167
273+
x = Float16[0.579, -0.729, 0.5493]
274+
δx = Float16[-0.001497, 0.0001875, -0.013176]
275+
276+
os = Optimisers.setup(Adam(1e-4), x);
277+
os, x = Optimisers.update(os, x, δx)
278+
@test x Float16[1.835, -0.886, 0.5493] rtol=1e-3
279+
end

0 commit comments

Comments
 (0)