Skip to content

Commit 149cc75

Browse files
tests
1 parent a9c5ef6 commit 149cc75

File tree

6 files changed

+57
-32
lines changed

6 files changed

+57
-32
lines changed

Project.toml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,4 @@ ConstructionBase = "1.5.8"
2626
EnzymeCore = "0.8.5"
2727
Functors = "0.4.9, 0.5"
2828
Statistics = "1"
29-
Zygote = "0.6.40, 0.7.1"
3029
julia = "1.10"
31-
32-
[extras]
33-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
34-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
35-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
36-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
37-
38-
[targets]
39-
test = ["Test", "EnzymeCore", "StaticArrays", "Zygote"]

src/interface.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ macro def(expr)
273273
# Positional-argument method, has defaults for all but the first arg:
274274
positional = :(function $rule($(names[1]), $(params[2:end]...))
275275
$check_sign_eta
276-
vars = maybe_float.([$(names...)])
276+
vars = $(maybe_float).(($(names...)),($(default_types...)))
277277
return new{typeof.(vars)...}(vars...)
278278
end)
279279
# Keyword-argument method. (Made an inner constructor only to allow
@@ -283,5 +283,5 @@ macro def(expr)
283283
return esc(expr)
284284
end
285285

286-
maybe_float(x::Number) = float(x)
287-
maybe_float(x) = x
286+
maybe_float(x, T::Type{<:AbstractFloat}) = float(x)
287+
maybe_float(x, T) = x

src/rules.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,16 @@ gradients by an estimate of their variance, instead of their second moment.
114114
- Keyword `centred` (or `centered`): Indicates whether to use centred variant
115115
of the algorithm.
116116
"""
117-
struct RMSProp <: AbstractRule
118-
eta::Float64
119-
rho::Float64
120-
epsilon::Float64
117+
struct RMSProp{Teta,Trho,Teps} <: AbstractRule
118+
eta::Teta
119+
rho::Trho
120+
epsilon::Teps
121121
centred::Bool
122122
end
123123

124124
function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
125125
η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
126-
RMSProp(η, ρ, ϵ, centred | centered)
126+
return RMSProp(float(η), float(ρ), float(ϵ), centred | centered)
127127
end
128128
RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)
129129

@@ -155,7 +155,7 @@ end
155155

156156

157157
"""
158-
Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
158+
Rprop(η = 1e-3, ℓ = (0.5, 1.2), Γ = (1e-6, 50.0))
159159
Rprop(; [eta, ell, gamma])
160160
161161
Optimizer using the
@@ -171,9 +171,9 @@ learning algorithm that depends only on the sign of the gradient.
171171
- Step sizes (`Γ::Tuple == gamma`): Minimal and maximal allowed step sizes.
172172
"""
173173
@def struct Rprop <: AbstractRule
174-
eta = 1f-3
175-
ell = (5f-1, 1.2f0)
176-
gamma = (1f-6, 50f0)
174+
eta = 1e-3
175+
ell = (0.5, 1.2)
176+
gamma = (1e-6, 50.0)
177177
end
178178

179179
init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
@@ -528,17 +528,17 @@ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`]
528528
The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
529529
See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
530530
"""
531-
struct AdamW <: AbstractRule
532-
eta::Float64
533-
beta::Tuple{Float64, Float64}
534-
lambda::Float64
535-
epsilon::Float64
531+
struct AdamW{Teta,Tbeta<:Tuple,Tlambda,Teps} <: AbstractRule
532+
eta::Teta
533+
beta::Tbeta
534+
lambda::Tlambda
535+
epsilon::Teps
536536
couple::Bool
537537
end
538538

539539
function AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8; couple::Bool = true)
540540
η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
541-
AdamW(η, β, λ, ϵ, couple)
541+
return AdamW(float(η), β, float(λ), float(ϵ), couple)
542542
end
543543

544544
AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda= 0.0, epsilon = 1e-8, kw...) =
@@ -704,12 +704,12 @@ Typically composed with other rules using [`OptimiserChain`](@ref).
704704
705705
See also [`ClipGrad`](@ref).
706706
"""
707-
struct ClipNorm <: AbstractRule
708-
omega::Float64
709-
p::Float64
707+
struct ClipNorm{To,Tp} <: AbstractRule
708+
omega::To
709+
p::Tp
710710
throw::Bool
711711
end
712-
ClipNorm(ω = 10, p = 2; throw::Bool = true) = ClipNorm(ω, p, throw)
712+
ClipNorm(ω = 10, p = 2; throw::Bool = true) = ClipNorm(float(ω), float(p), throw)
713713

714714
init(o::ClipNorm, x::AbstractArray) = nothing
715715

test/Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[deps]
2+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
4+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
5+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
6+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
7+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/interface.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
@testset "@def" begin
2+
Optimisers.@def struct DummyRule
3+
a = 1
4+
b1 = 1.5
5+
b2 = 2.5f0
6+
c = (1.0, 2.0)
7+
end
8+
9+
# no args
10+
r = DummyRule()
11+
@test typeof(r.a) == Int
12+
@test typeof(r.b1) == Float64
13+
@test typeof(r.b2) == Float32
14+
@test typeof(r.c) == Tuple{Float64, Float64}
15+
16+
# some positional args
17+
r = DummyRule(2, 2, 4.5)
18+
@test r.a == 2
19+
@test r.b1 == 2
20+
@test r.b2 == 4.5
21+
@test typeof(r.b1) == Float64 # int promoted to float
22+
@test typeof(r.b2) == Float64 # Float64 not converted to Float32
23+
@test r.c == (1.0, 2.0)
24+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,4 +557,7 @@ end
557557
@testset verbose=true "Optimisation Rules" begin
558558
include("rules.jl")
559559
end
560+
@testset verbose=true "interface" begin
561+
include("interface.jl")
562+
end
560563
end

0 commit comments

Comments
 (0)