diff --git a/Project.toml b/Project.toml
index f94d33244b..d265515a45 100644
--- a/Project.toml
+++ b/Project.toml
@@ -104,7 +104,7 @@ MacroTools = "0.5.13"
 Markdown = "1.10"
 NCCL = "0.1.1"
 NNlib = "0.9.26"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Preferences = "1.4.3"
 Random = "1.10"
 Reactant = "0.2.55"
diff --git a/docs/Project.toml b/docs/Project.toml
index 831d3b1414..ee5220eb8a 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -62,7 +62,7 @@ LuxLib = "1.3.4"
 LuxTestUtils = "1.5"
 MLDataDevices = "1.6.10"
 NNlib = "0.9.26"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
diff --git a/examples/Basics/Project.toml b/examples/Basics/Project.toml
index 087eec95a7..bd3aac2436 100644
--- a/examples/Basics/Project.toml
+++ b/examples/Basics/Project.toml
@@ -11,5 +11,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 ComponentArrays = "0.15.22"
 ForwardDiff = "0.10, 1"
 Lux = "1"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Zygote = "0.6.70, 0.7"
diff --git a/examples/CIFAR10/Project.toml b/examples/CIFAR10/Project.toml
index e7e90d1119..f74c0b65bf 100644
--- a/examples/CIFAR10/Project.toml
+++ b/examples/CIFAR10/Project.toml
@@ -35,7 +35,7 @@ LuxCUDA = "0.3.2"
 MLDatasets = "0.7.14"
 MLUtils = "0.4.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Printf = "1.10"
 ProgressTables = "0.1"
 Random = "1.10"
diff --git a/examples/ConvolutionalVAE/Project.toml b/examples/ConvolutionalVAE/Project.toml
index 093def46e8..8d0ea12461 100644
--- a/examples/ConvolutionalVAE/Project.toml
+++ b/examples/ConvolutionalVAE/Project.toml
@@ -23,7 +23,7 @@ Lux = "1.4.1"
 MLDatasets = "0.7.18"
 MLUtils = "0.4.4"
 OneHotArrays = "0.2.6"
-Optimisers = "0.4"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
diff --git a/examples/DDIM/Project.toml b/examples/DDIM/Project.toml
index de5e544bd2..7ecf7e3b89 100644
--- a/examples/DDIM/Project.toml
+++ b/examples/DDIM/Project.toml
@@ -36,7 +36,7 @@ JLD2 = "0.4.48, 0.5"
 Lux = "1"
 LuxCUDA = "0.3"
 MLUtils = "0.4"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 ParameterSchedulers = "0.4.1"
 ProgressBars = "1"
 Random = "1.10"
diff --git a/examples/GCN_Cora/Project.toml b/examples/GCN_Cora/Project.toml
index fada8bd20c..ce60e15891 100644
--- a/examples/GCN_Cora/Project.toml
+++ b/examples/GCN_Cora/Project.toml
@@ -18,7 +18,7 @@ GNNGraphs = "1"
 Lux = "1.5"
 MLDatasets = "0.7.18"
 OneHotArrays = "0.2"
-Optimisers = "0.4.4"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
diff --git a/examples/HyperNet/Project.toml b/examples/HyperNet/Project.toml
index e017048ba6..1bfbe76707 100644
--- a/examples/HyperNet/Project.toml
+++ b/examples/HyperNet/Project.toml
@@ -15,5 +15,5 @@ Lux = "1"
 MLDatasets = "0.7"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Reactant = "0.2.55"
diff --git a/examples/ImageNet/Project.toml b/examples/ImageNet/Project.toml
index 96957a2ce1..b4bacc8591 100644
--- a/examples/ImageNet/Project.toml
+++ b/examples/ImageNet/Project.toml
@@ -38,7 +38,7 @@ MLUtils = "0.4.4"
 MPI = "0.20.21"
 NCCL = "0.1.1"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 ParameterSchedulers = "0.4.2"
 Random = "1.10"
 Setfield = "1.1.1"
diff --git a/examples/NeuralODE/Project.toml b/examples/NeuralODE/Project.toml
index 99b7d7b297..ef804427b5 100644
--- a/examples/NeuralODE/Project.toml
+++ b/examples/NeuralODE/Project.toml
@@ -21,7 +21,7 @@ LuxCUDA = "0.3"
 MLDatasets = "0.7"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 OrdinaryDiffEqTsit5 = "1"
 SciMLSensitivity = "7.63"
 Statistics = "1"
diff --git a/examples/PINN2DPDE/Project.toml b/examples/PINN2DPDE/Project.toml
index 9dad2b3152..a4e4166ec5 100644
--- a/examples/PINN2DPDE/Project.toml
+++ b/examples/PINN2DPDE/Project.toml
@@ -18,7 +18,7 @@ Enzyme = "0.13"
 Lux = "1"
 MLUtils = "0.4.4"
 OnlineStats = "1.7.1"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
diff --git a/examples/PolynomialFitting/Project.toml b/examples/PolynomialFitting/Project.toml
index 939eabd585..f1f2d82a33 100644
--- a/examples/PolynomialFitting/Project.toml
+++ b/examples/PolynomialFitting/Project.toml
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 ADTypes = "1.10"
 CairoMakie = "0.12, 0.13"
 Lux = "1"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
diff --git a/examples/RealNVP/Project.toml b/examples/RealNVP/Project.toml
index a52474e29f..6c3b673b9f 100644
--- a/examples/RealNVP/Project.toml
+++ b/examples/RealNVP/Project.toml
@@ -16,7 +16,7 @@ ConcreteStructs = "0.2.3"
 Enzyme = "0.13.35"
 Lux = "1.5"
 MLUtils = "0.4.5"
-Optimisers = "0.4.4"
+Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.55"
diff --git a/examples/SimpleChains/Project.toml b/examples/SimpleChains/Project.toml
index be0f63a0d2..7dbfcbe973 100644
--- a/examples/SimpleChains/Project.toml
+++ b/examples/SimpleChains/Project.toml
@@ -17,7 +17,7 @@ Lux = "1"
 MLDatasets = "0.7.14"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Random = "1"
 Reactant = "0.2.55"
 SimpleChains = "0.4.6"
diff --git a/examples/SimpleRNN/Project.toml b/examples/SimpleRNN/Project.toml
index df235d9131..083c6ece8f 100644
--- a/examples/SimpleRNN/Project.toml
+++ b/examples/SimpleRNN/Project.toml
@@ -13,5 +13,5 @@ ADTypes = "1.10"
 JLD2 = "0.5"
 Lux = "1"
 MLUtils = "0.4"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Reactant = "0.2.55"
diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl
index 76e1be5ab1..f8c8a81b28 100644
--- a/ext/LuxReactantExt/LuxReactantExt.jl
+++ b/ext/LuxReactantExt/LuxReactantExt.jl
@@ -2,7 +2,8 @@ module LuxReactantExt
 
 using Enzyme: Enzyme, Const
 using Optimisers: Optimisers
-using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber
+using Reactant:
+    Reactant, @compile, @code_hlo, @jit, AnyTracedRArray, TracedRArray, TracedRNumber
 using ReactantCore: ReactantCore, @trace
 using Setfield: @set!
 using Static: True, False
diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl
index bc9ad74b20..6e94499fe8 100644
--- a/ext/LuxReactantExt/patches.jl
+++ b/ext/LuxReactantExt/patches.jl
@@ -2,3 +2,8 @@ Utils.vec(x::AnyTracedRArray) = ReactantCore.materialize_traced_array(vec(x))
 
 # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
 Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g
+
+# Optimisers setup
+function Lux.ReactantCompatibleOptimisers.optimisers_setup_with_jit(opt, ps)
+    return @jit Optimisers.setup(opt, ps)
+end
diff --git a/src/helpers/optimizers.jl b/src/helpers/optimizers.jl
index 69ebfc08a8..9fb837b3d8 100644
--- a/src/helpers/optimizers.jl
+++ b/src/helpers/optimizers.jl
@@ -1,163 +1,23 @@
-# This is mostly an internal implementation detail that users shouldn't need to worry about.
-# We can remove this once https://github.com/FluxML/Optimisers.jl/issues/205 is resolved.
+# We use this module mostly as a placeholder for patches that should be merged into
+# Optimisers.jl for Reactant compatibility.
 module ReactantCompatibleOptimisers
 
-using ConcreteStructs: @concrete
-using Optimisers: Optimisers, AbstractRule
-using Setfield: Setfield, @set!
+using Optimisers: Optimisers
 
 using ..Lux: Lux, Utils
 
-abstract type ReactantCompatibleOptimisersRule <: AbstractRule end
-
-function make_reactant_compatible(opt::AbstractRule)
-    @warn "`make_reactant_compatible` is not defined for $(opt). Returning the original \
-           optimizer. This means adjusting learning rate and other parameters won't \
-           reflect in the generated MLIR." maxlog = 1
-    return opt
-end
-make_reactant_compatible(opt::ReactantCompatibleOptimisersRule) = opt
-
-function setfield_if_present(opt, field::Symbol, nt::NamedTuple)
-    if hasfield(typeof(nt), field)
-        return Setfield.set(
-            opt,
-            Setfield.PropertyLens{field}(),
-            convert(typeof(getproperty(opt, field)), getproperty(nt, field)),
-        )
-    end
-    return opt
-end
-
-function Optimisers._adjust(opt::ReactantCompatibleOptimisersRule, nt::NamedTuple)
-    for field in fieldnames(typeof(opt))
-        opt = setfield_if_present(opt, field, nt)
-    end
-    return opt
-end
-
-# OptimiserChain
 function make_reactant_compatible(opt::Optimisers.OptimiserChain)
-    return Optimisers.OptimiserChain(make_reactant_compatible.(opt.opts))
-end
-
-# Descent
-@concrete struct ReactantDescent <: ReactantCompatibleOptimisersRule
-    eta
-end
-
-function make_reactant_compatible(opt::Optimisers.Descent)
-    return ReactantDescent(Utils.to_rarray(opt.eta; track_numbers=true))
-end
-
-Optimisers.init(::ReactantDescent, ::AbstractArray) = nothing
-
-function Optimisers.apply!(opt::ReactantDescent, state, x::AbstractArray{T}, dx) where {T}
-    η = T(opt.eta)
-    return state, @. dx * η
-end
-
-# Momentum
-@concrete struct ReactantMomentum <: ReactantCompatibleOptimisersRule
-    eta
-    rho
-end
-
-function make_reactant_compatible(opt::Optimisers.Momentum)
-    return ReactantMomentum(
-        Utils.to_rarray(opt.eta; track_numbers=true),
-        Utils.to_rarray(opt.rho; track_numbers=true),
-    )
-end
-
-function Optimisers.init(::ReactantMomentum, x::AbstractArray)
-    return Optimisers.init(Optimisers.Momentum(0.0, 0.0), x)
-end
-
-function Optimisers.apply!(opt::ReactantMomentum, mvel, ::AbstractArray{T}, dx) where {T}
-    η, ρ = T(opt.eta), T(opt.rho)
-    @. mvel = ρ * mvel + η * dx
-    return mvel, mvel
-end
-
-# Adam
-@concrete struct ReactantAdam <: ReactantCompatibleOptimisersRule
-    eta
-    beta
-    epsilon
+    return Optimisers.OptimiserChain(make_reactant_compatible.(opt.opts)...)
 end
 
-function make_reactant_compatible(opt::Optimisers.Adam)
-    return ReactantAdam(
-        Utils.to_rarray(opt.eta; track_numbers=true),
-        Utils.to_rarray(opt.beta; track_numbers=true),
-        Utils.to_rarray(opt.epsilon; track_numbers=true),
-    )
+function make_reactant_compatible(opt::Optimisers.AbstractRule)
+    return Utils.to_rarray(opt; track_numbers=AbstractFloat)
 end
 
-function Optimisers.init(opt::ReactantAdam, x::AbstractArray{T}) where {T}
-    return (
-        zero(x),
-        zero(x),
-        (Utils.promote_to(T, opt.beta[1]), Utils.promote_to(T, opt.beta[2])),
-    )
+function make_reactant_compatible(opt::Optimisers.AccumGrad)
+    return Utils.to_rarray(opt; track_numbers=Integer)
 end
 
-function Optimisers.apply!(o::ReactantAdam, state, ::AbstractArray{T}, dx) where {T}
-    η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon) # XXX: See Optimisers._eps
-    mt, vt, βt = state
-
-    mt = @. β[1] * mt + (1 - β[1]) * dx
-    vt = @. β[2] * vt + (1 - β[2]) * abs2(dx)
-    dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η
-
-    return (mt, vt, βt .* β), dx′
-end
-
-# AdamW
-@concrete struct ReactantAdamW <: ReactantCompatibleOptimisersRule
-    eta
-    beta
-    lambda
-    epsilon
-    couple::Bool
-end
-
-function make_reactant_compatible(opt::Optimisers.AdamW)
-    return ReactantAdamW(
-        Utils.to_rarray(opt.eta; track_numbers=true),
-        Utils.to_rarray(opt.beta; track_numbers=true),
-        Utils.to_rarray(opt.lambda; track_numbers=true),
-        Utils.to_rarray(opt.epsilon; track_numbers=true),
-        opt.couple,
-    )
-end
-
-function Optimisers.init(opt::ReactantAdamW, x::AbstractArray{T}) where {T}
-    return (
-        zero(x),
-        zero(x),
-        (Utils.promote_to(T, opt.beta[1]), Utils.promote_to(T, opt.beta[2])),
-    )
-end
-
-function Optimisers.apply!(o::ReactantAdamW, state, x::AbstractArray{T}, dx) where {T}
-    η, β, ϵ, λ = T(o.eta), T.(o.beta), T(o.epsilon), T(o.lambda) # XXX: See Optimisers._eps
-    mt, vt, βt = state
-
-    # standard Adam update with learning rate eta=1
-    mt = @. β[1] * mt + (1 - β[1]) * dx
-    vt = @. β[2] * vt + (1 - β[2]) * abs2(dx)
-    dx′ = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
-
-    # apply learning rate and weight decay
-    if o.couple
-        dx′′ = @. η * (dx′ + λ * x)
-    else
-        dx′′ = @. η * dx′ + λ * x
-    end
-
-    return (mt, vt, βt .* β), dx′′
-end
+function optimisers_setup_with_jit end
 
 end
diff --git a/src/helpers/training.jl b/src/helpers/training.jl
index 537547a594..332a9b0abd 100644
--- a/src/helpers/training.jl
+++ b/src/helpers/training.jl
@@ -63,10 +63,11 @@ Constructor for [`TrainState`](@ref).
 [`TrainState`](@ref) object.
 """
 function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
-    st_opt = if get_device_type(ps) <: ReactantDevice
-        Optimisers.setup(ReactantCompatibleOptimisers.make_reactant_compatible(optimizer), ps)
+    if get_device_type(ps) <: ReactantDevice
+        optimizer = ReactantCompatibleOptimisers.make_reactant_compatible(optimizer)
+        st_opt = ReactantCompatibleOptimisers.optimisers_setup_with_jit(optimizer, ps)
     else
-        Optimisers.setup(optimizer, ps)
+        st_opt = Optimisers.setup(optimizer, ps)
     end
     return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
 end
diff --git a/test/Project.toml b/test/Project.toml
index cd3cc1cd27..d95dff6dcf 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -65,7 +65,7 @@ MLUtils = "0.4.3"
 NNlib = "0.9.26"
 Octavian = "0.3.28"
 OneHotArrays = "0.2.5"
-Optimisers = "0.4.1"
+Optimisers = "0.4.6"
 Pkg = "1.10"
 Preferences = "1.4.3"
 Random = "1.10"
diff --git a/test/layers/embedding_tests.jl b/test/layers/embedding_tests.jl
index e59ab70cb9..1aed1db311 100644
--- a/test/layers/embedding_tests.jl
+++ b/test/layers/embedding_tests.jl
@@ -116,7 +116,7 @@ end
     end
 end
 
-@testitem "Reactant SinusoidalPositionalEmbedding" setup = [
+@testitem "Reactant: SinusoidalPositionalEmbedding" setup = [
     SharedTestSetup, SharedReactantLayersTestSetup
 ] tags = [:reactant] begin
     using Reactant, Lux
@@ -234,7 +234,7 @@ end
     end
 end
 
-@testitem "Reactant RotaryPositionalEmbedding" setup = [
+@testitem "Reactant: RotaryPositionalEmbedding" setup = [
     SharedTestSetup, SharedReactantLayersTestSetup
 ] tags = [:reactant] begin
     using Reactant, Lux
diff --git a/test/reactant/training_tests.jl b/test/reactant/training_tests.jl
index a869f0d108..00205f9016 100644
--- a/test/reactant/training_tests.jl
+++ b/test/reactant/training_tests.jl
@@ -47,8 +47,13 @@
             inference_loss_fn_compiled(xᵢ, yᵢ, model, ps, st)
         end
 
-        @testset for opt in
-            (Descent(0.01f0), Momentum(0.01f0), Adam(0.01f0), AdamW(0.01f0))
+        @testset for opt in (
+            Descent(0.01f0),
+            Momentum(0.01f0),
+            Adam(0.01f0),
+            AdamW(0.01f0),
+            OptimiserChain(AccumGrad(5), Adam(0.01f0)),
+        )
             train_state = Training.TrainState(model, ps, st, opt)
 
             for epoch in 1:100, (xᵢ, yᵢ) in dataloader
@@ -76,3 +81,26 @@
         end
     end
 end
+
+@testitem "Reactant Optimisers Patch: AccumGrad" tags = [:reactant] setup = [
+    SharedTestSetup
+] skip = :(Sys.iswindows()) begin
+    using Lux, Random, Reactant, Optimisers
+
+    dev = reactant_device(; force=true)
+
+    model = Chain(
+        Dense(2 => 4, relu), Chain(Dense(4 => 2, relu; use_bias=false), Dense(2 => 1))
+    )
+    ps, st = Lux.setup(Random.default_rng(), model) |> dev
+
+    x = randn(Float32, 2, 32) |> dev
+
+    train_state = Training.TrainState(
+        model, ps, st, OptimiserChain(AccumGrad(5), Descent(0.1))
+    )
+    st_opt = train_state.optimizer_state
+
+    hlo = repr(@code_hlo(Optimisers.update(st_opt, ps, ps)))
+    @test length(findall("stablehlo.if", hlo)) == (2 + 1 + 2) * 2
+end
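
A note on the new test's final assertion, since the expected count is easy to misread: the breakdown below is inferred from the test model, not stated in the patch. The `Chain` holds three `Dense` layers with 2, 1, and 2 trainable arrays respectively (the middle layer sets `use_bias=false`), and `AccumGrad`'s traced update appears to lower to two `stablehlo.if` conditionals per parameter array, hence `(2 + 1 + 2) * 2`.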
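
For orientation, a minimal end-to-end sketch of the code path this PR changes; this is not part of the patch, and the model and hyperparameters are illustrative. On a Reactant device, `TrainState` now converts the rule with `make_reactant_compatible` and runs `Optimisers.setup` under `@jit` via `optimisers_setup_with_jit`, instead of swapping in the hand-written `Reactant*` rule structs deleted above:

using Lux, Optimisers, Random, Reactant

dev = reactant_device()  # requires a working Reactant/XLA backend

model = Dense(2 => 1)
ps, st = Lux.setup(Random.default_rng(), model) |> dev

# With Optimisers.jl >= 0.4.6 the stock rules (including AccumGrad) trace
# through Reactant, so no custom rule structs are needed.
train_state = Training.TrainState(
    model, ps, st, OptimiserChain(AccumGrad(4), Adam(0.01f0))
)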
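
Likewise for the `track_numbers` keyword: assuming `Utils.to_rarray` forwards to `Reactant.to_rarray` (the wrapper itself is not shown in this diff), passing a type converts matching scalar fields of the rule into traced numbers. Hyperparameters then enter the compiled executable as runtime inputs, so `Optimisers.adjust!` can change, say, the learning rate without triggering recompilation, which is exactly what the removed `@warn` fallback could not guarantee. A rough sketch under those assumptions:

using Optimisers, Reactant

# Floating-point fields (eta, beta, epsilon) become traced Reactant numbers;
# the rule's structure is otherwise preserved.
adam_ra = Reactant.to_rarray(Adam(0.01); track_numbers=AbstractFloat)

# AccumGrad's step count `n` is an Int, hence track_numbers=Integer in the patch.
accum_ra = Reactant.to_rarray(AccumGrad(5); track_numbers=Integer)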