2 changes: 1 addition & 1 deletion Project.toml
@@ -104,7 +104,7 @@ MacroTools = "0.5.13"
 Markdown = "1.10"
 NCCL = "0.1.1"
 NNlib = "0.9.26"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Preferences = "1.4.3"
 Random = "1.10"
 Reactant = "0.2.55"
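Note: every environment in this PR raises its [compat] lower bound for Optimisers.jl (from 0.4.1, 0.4, or 0.4.4) to 0.4.6. Under Julia's caret-style compat semantics, Optimisers = "0.4.6" admits any version in [0.4.6, 0.5.0). A sketch of applying the same bump to one of these environments locally (run from that project's root; the specific path is up to you):

    using Pkg
    Pkg.activate(".")                  # the environment whose Project.toml is being edited
    Pkg.compat("Optimisers", "0.4.6")  # rewrite the [compat] entry in place
    Pkg.update("Optimisers")           # re-resolve against the new lower bound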
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -62,7 +62,7 @@ LuxLib = "1.3.4"
 LuxTestUtils = "1.5"
 MLDataDevices = "1.6.10"
 NNlib = "0.9.26"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
2 changes: 1 addition & 1 deletion examples/Basics/Project.toml
@@ -11,5 +11,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 ComponentArrays = "0.15.22"
 ForwardDiff = "0.10, 1"
 Lux = "1"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Zygote = "0.6.70, 0.7"
2 changes: 1 addition & 1 deletion examples/CIFAR10/Project.toml
@@ -35,7 +35,7 @@ LuxCUDA = "0.3.2"
 MLDatasets = "0.7.14"
 MLUtils = "0.4.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Printf = "1.10"
 ProgressTables = "0.1"
 Random = "1.10"
2 changes: 1 addition & 1 deletion examples/ConvolutionalVAE/Project.toml
@@ -23,7 +23,7 @@ Lux = "1.4.1"
 MLDatasets = "0.7.18"
 MLUtils = "0.4.4"
 OneHotArrays = "0.2.6"
-Optimisers = "0.4"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
2 changes: 1 addition & 1 deletion examples/DDIM/Project.toml
@@ -36,7 +36,7 @@ JLD2 = "0.4.48, 0.5"
 Lux = "1"
 LuxCUDA = "0.3"
 MLUtils = "0.4"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 ParameterSchedulers = "0.4.1"
 ProgressBars = "1"
 Random = "1.10"
2 changes: 1 addition & 1 deletion examples/GCN_Cora/Project.toml
@@ -18,7 +18,7 @@ GNNGraphs = "1"
 Lux = "1.5"
 MLDatasets = "0.7.18"
 OneHotArrays = "0.2"
-Optimisers = "0.4.4"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
2 changes: 1 addition & 1 deletion examples/HyperNet/Project.toml
@@ -15,5 +15,5 @@ Lux = "1"
 MLDatasets = "0.7"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Reactant = "0.2.55"
2 changes: 1 addition & 1 deletion examples/ImageNet/Project.toml
@@ -38,7 +38,7 @@ MLUtils = "0.4.4"
 MPI = "0.20.21"
 NCCL = "0.1.1"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 ParameterSchedulers = "0.4.2"
 Random = "1.10"
 Setfield = "1.1.1"
2 changes: 1 addition & 1 deletion examples/NeuralODE/Project.toml
@@ -21,7 +21,7 @@ LuxCUDA = "0.3"
 MLDatasets = "0.7"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 OrdinaryDiffEqTsit5 = "1"
 SciMLSensitivity = "7.63"
 Statistics = "1"
2 changes: 1 addition & 1 deletion examples/PINN2DPDE/Project.toml
@@ -18,7 +18,7 @@ Enzyme = "0.13"
 Lux = "1"
 MLUtils = "0.4.4"
 OnlineStats = "1.7.1"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
2 changes: 1 addition & 1 deletion examples/PolynomialFitting/Project.toml
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 ADTypes = "1.10"
 CairoMakie = "0.12, 0.13"
 Lux = "1"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
2 changes: 1 addition & 1 deletion examples/RealNVP/Project.toml
@@ -16,7 +16,7 @@ ConcreteStructs = "0.2.3"
 Enzyme = "0.13.35"
 Lux = "1.5"
 MLUtils = "0.4.5"
-Optimisers = "0.4.4"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
2 changes: 1 addition & 1 deletion examples/SimpleChains/Project.toml
@@ -17,7 +17,7 @@ Lux = "1"
 MLDatasets = "0.7.14"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Random = "1"
 Reactant = "0.2.55"
 SimpleChains = "0.4.6"
2 changes: 1 addition & 1 deletion examples/SimpleRNN/Project.toml
@@ -13,5 +13,5 @@ ADTypes = "1.10"
 JLD2 = "0.5"
 Lux = "1"
 MLUtils = "0.4"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Reactant = "0.2.55"
3 changes: 2 additions & 1 deletion ext/LuxReactantExt/LuxReactantExt.jl
@@ -2,7 +2,8 @@ module LuxReactantExt
 
 using Enzyme: Enzyme, Const
 using Optimisers: Optimisers
-using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber
+using Reactant:
+    Reactant, @compile, @code_hlo, @jit, AnyTracedRArray, TracedRArray, TracedRNumber
 using ReactantCore: ReactantCore, @trace
 using Setfield: @set!
 using Static: True, False
5 changes: 5 additions & 0 deletions ext/LuxReactantExt/patches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@ Utils.vec(x::AnyTracedRArray) = ReactantCore.materialize_traced_array(vec(x))

# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g

# Optimisers setup
function Lux.ReactantCompatibleOptimisers.optimisers_setup_with_jit(opt, ps)
return @jit Optimisers.setup(opt, ps)
end
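Note: `optimisers_setup_with_jit` is only declared inside Lux (see src/helpers/optimizers.jl below); the method is defined here in the Reactant extension, where `@jit` is available. Running `Optimisers.setup` under `@jit` traces and executes it on the Reactant device, so the optimizer state (momenta, counters) is created as device arrays rather than host arrays. A rough usage sketch, assuming a Reactant-capable environment; `rdev`, the parameter NamedTuple, and the direct calls into the internal `ReactantCompatibleOptimisers` module are illustrative only:

    using Lux, Optimisers, Reactant

    rdev = reactant_device()
    ps = rdev((; weight=randn(Float32, 4, 2), bias=zeros(Float32, 4)))

    # Convert the rule, then build its state on-device under @jit:
    opt = Lux.ReactantCompatibleOptimisers.make_reactant_compatible(Adam(0.001f0))
    st_opt = Lux.ReactantCompatibleOptimisers.optimisers_setup_with_jit(opt, ps)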
158 changes: 9 additions & 149 deletions src/helpers/optimizers.jl
@@ -1,163 +1,23 @@
-# This is mostly an internal implementation detail that users shouldn't need to worry about.
-# We can remove this once https://github.com/FluxML/Optimisers.jl/issues/205 is resolved.
+# We use this module mostly as a placeholder for patches that should be merged into
+# Optimisers.jl for Reactant compatibility.
 module ReactantCompatibleOptimisers
 
-using ConcreteStructs: @concrete
-using Optimisers: Optimisers, AbstractRule
-using Setfield: Setfield, @set!
+using Optimisers: Optimisers
 
 using ..Lux: Lux, Utils
 
-abstract type ReactantCompatibleOptimisersRule <: AbstractRule end
-
-function make_reactant_compatible(opt::AbstractRule)
-    @warn "`make_reactant_compatible` is not defined for $(opt). Returning the original \
-           optimizer. This means adjusting learning rate and other parameters won't \
-           reflect in the generated MLIR." maxlog = 1
-    return opt
-end
-make_reactant_compatible(opt::ReactantCompatibleOptimisersRule) = opt
-
-function setfield_if_present(opt, field::Symbol, nt::NamedTuple)
-    if hasfield(typeof(nt), field)
-        return Setfield.set(
-            opt,
-            Setfield.PropertyLens{field}(),
-            convert(typeof(getproperty(opt, field)), getproperty(nt, field)),
-        )
-    end
-    return opt
-end
-
-function Optimisers._adjust(opt::ReactantCompatibleOptimisersRule, nt::NamedTuple)
-    for field in fieldnames(typeof(opt))
-        opt = setfield_if_present(opt, field, nt)
-    end
-    return opt
-end
-
-# OptimiserChain
 function make_reactant_compatible(opt::Optimisers.OptimiserChain)
-    return Optimisers.OptimiserChain(make_reactant_compatible.(opt.opts))
+    return Optimisers.OptimiserChain(make_reactant_compatible.(opt.opts)...)
 end
 
-# Descent
-@concrete struct ReactantDescent <: ReactantCompatibleOptimisersRule
-    eta
-end
-
-function make_reactant_compatible(opt::Optimisers.Descent)
-    return ReactantDescent(Utils.to_rarray(opt.eta; track_numbers=true))
+function make_reactant_compatible(opt::Optimisers.AbstractRule)
+    return Utils.to_rarray(opt; track_numbers=AbstractFloat)
 end
 
-Optimisers.init(::ReactantDescent, ::AbstractArray) = nothing
-
-function Optimisers.apply!(opt::ReactantDescent, state, x::AbstractArray{T}, dx) where {T}
-    η = T(opt.eta)
-    return state, @. dx * η
+function make_reactant_compatible(opt::Optimisers.AccumGrad)
+    return Utils.to_rarray(opt; track_numbers=Integer)
 end
 
-# Momentum
-@concrete struct ReactantMomentum <: ReactantCompatibleOptimisersRule
-    eta
-    rho
-end
-
-function make_reactant_compatible(opt::Optimisers.Momentum)
-    return ReactantMomentum(
-        Utils.to_rarray(opt.eta; track_numbers=true),
-        Utils.to_rarray(opt.rho; track_numbers=true),
-    )
-end
-
-function Optimisers.init(::ReactantMomentum, x::AbstractArray)
-    return Optimisers.init(Optimisers.Momentum(0.0, 0.0), x)
-end
-
-function Optimisers.apply!(opt::ReactantMomentum, mvel, ::AbstractArray{T}, dx) where {T}
-    η, ρ = T(opt.eta), T(opt.rho)
-    @. mvel = ρ * mvel + η * dx
-    return mvel, mvel
-end
-
-# Adam
-@concrete struct ReactantAdam <: ReactantCompatibleOptimisersRule
-    eta
-    beta
-    epsilon
-end
-
-function make_reactant_compatible(opt::Optimisers.Adam)
-    return ReactantAdam(
-        Utils.to_rarray(opt.eta; track_numbers=true),
-        Utils.to_rarray(opt.beta; track_numbers=true),
-        Utils.to_rarray(opt.epsilon; track_numbers=true),
-    )
-end
-
-function Optimisers.init(opt::ReactantAdam, x::AbstractArray{T}) where {T}
-    return (
-        zero(x),
-        zero(x),
-        (Utils.promote_to(T, opt.beta[1]), Utils.promote_to(T, opt.beta[2])),
-    )
-end
-
-function Optimisers.apply!(o::ReactantAdam, state, ::AbstractArray{T}, dx) where {T}
-    η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon) # XXX: See Optimisers._eps
-    mt, vt, βt = state
-
-    mt = @. β[1] * mt + (1 - β[1]) * dx
-    vt = @. β[2] * vt + (1 - β[2]) * abs2(dx)
-    dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η
-
-    return (mt, vt, βt .* β), dx′
-end
-
-# AdamW
-@concrete struct ReactantAdamW <: ReactantCompatibleOptimisersRule
-    eta
-    beta
-    lambda
-    epsilon
-    couple::Bool
-end
-
-function make_reactant_compatible(opt::Optimisers.AdamW)
-    return ReactantAdamW(
-        Utils.to_rarray(opt.eta; track_numbers=true),
-        Utils.to_rarray(opt.beta; track_numbers=true),
-        Utils.to_rarray(opt.lambda; track_numbers=true),
-        Utils.to_rarray(opt.epsilon; track_numbers=true),
-        opt.couple,
-    )
-end
-
-function Optimisers.init(opt::ReactantAdamW, x::AbstractArray{T}) where {T}
-    return (
-        zero(x),
-        zero(x),
-        (Utils.promote_to(T, opt.beta[1]), Utils.promote_to(T, opt.beta[2])),
-    )
-end
-
-function Optimisers.apply!(o::ReactantAdamW, state, x::AbstractArray{T}, dx) where {T}
-    η, β, ϵ, λ = T(o.eta), T.(o.beta), T(o.epsilon), T(o.lambda) # XXX: See Optimisers._eps
-    mt, vt, βt = state
-
-    # standard Adam update with learning rate eta=1
-    mt = @. β[1] * mt + (1 - β[1]) * dx
-    vt = @. β[2] * vt + (1 - β[2]) * abs2(dx)
-    dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
-
-    # apply learning rate and weight decay
-    if o.couple
-        dx′′ = @. η * (dx′ + λ * x)
-    else
-        dx′′ = @. η * dx′ + λ * x
-    end
-
-    return (mt, vt, βt .* β), dx′′
-end
+function optimisers_setup_with_jit end
 
 end
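Note: the hand-rolled ReactantDescent/ReactantMomentum/ReactantAdam/ReactantAdamW mirror rules, along with their `Optimisers.init`/`Optimisers.apply!` methods and the custom `_adjust` plumbing, are deleted in favor of a single generic conversion; that conversion relies on behavior available from Optimisers.jl >= 0.4.6, which is what the compat bumps above enforce. `Utils.to_rarray(opt; track_numbers=AbstractFloat)` rewrites each floating-point hyperparameter of a rule (eta, beta, epsilon, ...) as a Reactant number, so learning-rate adjustments flow into the compiled program; `AccumGrad` tracks `Integer` instead because its hyperparameter is an accumulation count. A small sketch of the conversion entry point (illustrative; `ReactantCompatibleOptimisers` is an internal module):

    using Lux, Optimisers, Reactant

    opt = OptimiserChain(AccumGrad(4), Adam(1.0f-3))
    # The chain is rebuilt element-wise: Adam's floating-point hyperparameters
    # and AccumGrad's integer count become Reactant numbers via Utils.to_rarray.
    ropt = Lux.ReactantCompatibleOptimisers.make_reactant_compatible(opt)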
7 changes: 4 additions & 3 deletions src/helpers/training.jl
@@ -63,10 +63,11 @@ Constructor for [`TrainState`](@ref).
 [`TrainState`](@ref) object.
 """
 function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
-    st_opt = if get_device_type(ps) <: ReactantDevice
-        Optimisers.setup(ReactantCompatibleOptimisers.make_reactant_compatible(optimizer), ps)
+    if get_device_type(ps) <: ReactantDevice
+        optimizer = ReactantCompatibleOptimisers.make_reactant_compatible(optimizer)
+        st_opt = ReactantCompatibleOptimisers.optimisers_setup_with_jit(optimizer, ps)
     else
-        Optimisers.setup(optimizer, ps)
+        st_opt = Optimisers.setup(optimizer, ps)
     end
     return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
 end
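Note: besides routing the Reactant branch through `optimisers_setup_with_jit`, this also changes which rule the `TrainState` stores: `optimizer` is reassigned before construction, so the returned state carries the reactant-compatible rule rather than the original one. A minimal end-to-end sketch of the public path (assumes a Reactant-capable environment; the model and sizes are arbitrary, and `reactant_device` comes from MLDataDevices):

    using Lux, Optimisers, Random, Reactant

    model = Dense(2 => 4, tanh)
    ps, st = Lux.setup(Random.default_rng(), model) |> reactant_device()

    # get_device_type(ps) <: ReactantDevice here, so the optimizer is made
    # reactant-compatible and its state is built under @jit:
    tstate = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))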
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -65,7 +65,7 @@ MLUtils = "0.4.3"
 NNlib = "0.9.26"
 Octavian = "0.3.28"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Pkg = "1.10"
 Preferences = "1.4.3"
 Random = "1.10"
4 changes: 2 additions & 2 deletions test/layers/embedding_tests.jl
@@ -116,7 +116,7 @@ end
 end
 end
 
-@testitem "Reactant SinusoidalPositionalEmbedding" setup = [
+@testitem "Reactant: SinusoidalPositionalEmbedding" setup = [
     SharedTestSetup, SharedReactantLayersTestSetup
 ] tags = [:reactant] begin
     using Reactant, Lux
@@ -234,7 +234,7 @@ end
 end
 end
 
-@testitem "Reactant RotaryPositionalEmbedding" setup = [
+@testitem "Reactant: RotaryPositionalEmbedding" setup = [
     SharedTestSetup, SharedReactantLayersTestSetup
 ] tags = [:reactant] begin
     using Reactant, Lux