diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 99fc162f..0bce300d 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -1,9 +1,10 @@
 module Optimisers
 
-using Functors: functor, fmap, fmap_with_path,
+using Functors: functor, fmap, fmap_with_path, KeyPath, haskeypath, getkeypath,
                 isleaf, @functor, fmapstructure, children, AbstractWalk
 using LinearAlgebra
+import LinearAlgebra: norm
 
 include("interface.jl")
 export AbstractRule
 
@@ -23,7 +24,7 @@ include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
-       AccumGrad
+       AccumGrad, Muon
 
 VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))
 
diff --git a/src/rules.jl b/src/rules.jl
index 0cd8d30c..db74e8ed 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -181,7 +181,7 @@ init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
 function apply!(o::Rprop, state, x::AbstractArray{T}, dx) where T
   ℓ, Γ = T.(o.ell), T.(o.gamma)
   g, η = state
-  
+
   η = broadcast(g, η, dx) do g, η, dx
     g * dx > 0 ? min(η * ℓ[2], Γ[2]) : g * dx < 0 ? max(η * ℓ[1], Γ[1]) : η
   end
@@ -256,6 +256,7 @@ function apply!(o::Lion, state, x::AbstractArray{T}, dx) where T
 
   return state, dx′
 end
+
 """
     RAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
     RAdam(; [eta, beta, epsilon])
@@ -599,14 +600,89 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
   return (mt, st, βt .* β), dx′
 end
 
+nonfirstdims(x) = prod(size(x)[2:end])  # product of all dims after the first; == 1 for vectors
+
+"""
+    Muon(opt = AdamW(eta = 0.0003, beta = (0.9,0.95), lambda = 0.01), η = 0.02, μ = 0.95, λ = 0.01, fallback = x -> false)
+    Muon(; [opt, eta, mu, lambda, fallback])
+
+Muon - MomentUm Orthogonalized by Newton-schulz (https://github.com/KellerJordan/Muon)
+
+Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step,
+in which each 2D parameter's update is replaced with the nearest semi-orthogonal matrix, approximated by a Newton-Schulz iteration.
+
+# Parameters
+- Fallback optimizer (`opt`): Optimizer used for 1D parameters, and whenever the `fallback` function returns `true`
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating the weights
+- Momentum (`μ == mu`): Controls the acceleration of gradient descent in the prominent direction
+- Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation
+- Fallback function (`fallback`): Predicate controlling when, in addition to 1D arrays, the fallback optimizer should be used. It is passed the parameter array and must return a `Bool`
+
+Note: Muon works best with large batch sizes and may not be suitable for fine-tuning.
+In the nanoGPT speedrun experiments, Muon is used for the ≥2D weights of the internal layers, while AdamW is used for the 1D weights, the embeddings, and the heads.
+
+`Optimisers.adjust!(optimiser_state, η::Real)` sets Muon's `eta` to `η` and scales the fallback optimizer's `eta` by the same factor, preserving the ratio of the two learning rates.
+By contrast, `Optimisers.adjust!(optimiser_state, eta = η)` adjusts only Muon's learning rate, allowing the fallback optimizer's learning rate to be adjusted separately.
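+
+# Example
+
+A minimal usage sketch (the parameter shapes and values here are purely illustrative).
+The 2D `weight` is updated by Muon itself, while the 1D `bias` is routed to the fallback optimizer `opt`:
+
+```julia
+model = (weight = randn(Float32, 64, 32), bias = zeros(Float32, 64))
+state = Optimisers.setup(Muon(eta = 0.02), model)
+grad = (weight = randn(Float32, 64, 32), bias = randn(Float32, 64))
+state, model = Optimisers.update!(state, model, grad)
+```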
+""" +struct Muon <: AbstractRule + opt::AbstractRule + eta::Float64 + mu::Float64 + lambda::Float64 + fallback::Function +end + +Muon(;opt = AdamW(eta = 0.0003, beta = (0.9,0.95), lambda = 0.01), eta = 0.02, mu = 0.95, lambda = 0.01, fallback = x -> false) = Muon(opt, eta, mu, lambda, fallback) + +function init(o::Muon, x::AbstractArray) + if nonfirstdims(x) == 1 || o.fallback(x) + return init(o.opt, x) + else + return zero(x) + end +end + +function apply!(o::Muon, state, x::AbstractArray{T}, dx) where T + if nonfirstdims(x) == 1 || o.fallback(x) + return apply!(o.opt, state, x, dx) + else + η, μ, λ = T(o.eta), T(o.mu), T(o.lambda) + @.. state = μ * state + dx + Ot = _newton_schulz5(μ .* state .+ dx) * T(sqrt(max(1, size(x,1)/nonfirstdims(x)))) + dx′ = @lazy η * (Ot + λ * x) + return state, dx′ + end +end + +function _inner_newton_schulz5(X::AbstractMatrix{T}) where T + a, b, c = (T(3.4445f0), T(-4.7750f0), T(2.0315f0)) + for _ in 1:5 + A = X * X' + B = b * A + c * A * A + X = a * X + B * X + end + X +end +function _newton_schulz5(G::AbstractMatrix{T}) where T + X = G / (norm(G) + eps(T)) + if size(G, 1) > size(G, 2) + transpose(_inner_newton_schulz5(transpose(X))) + else + _inner_newton_schulz5(X) + end +end +_newton_schulz5(G::AbstractArray) = reshape(_newton_schulz5(reshape(G, size(G,1), :)), size(G)) + +adjust(r::Muon, η::Real) = adjust(r, eta = η, opt = adjust(r.opt, eta = (r.opt.eta / r.eta) * η)) + """ WeightDecay(λ = 5e-4) WeightDecay(; [lambda]) -Implements ``L_2`` regularisation, also known as ridge regression, +Implements ``L_2`` regularisation, also known as ridge regression, when composed with other rules as the first transformation in an [`OptimiserChain`](@ref). -It does this by adding `λ .* x` to the gradient. This is equivalent to adding +It does this by adding `λ .* x` to the gradient. This is equivalent to adding `λ/2 * sum(abs2, x) == λ/2 * norm(x)^2` to the loss. See also [`SignDecay`] for ``L_1`` normalisation. @@ -644,7 +720,7 @@ function adjust(r::WeightDecay; gamma = nothing, kw...) Implements ``L_1`` regularisation, also known as LASSO regression, when composed with other rules as the first transformation in an [`OptimiserChain`](@ref). -It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding +It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding `λ * sum(abs, x) == λ * norm(x, 1)` to the loss. See also [`WeightDecay`] for ``L_2`` normalisation. @@ -783,7 +859,7 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...) foldl(tuple.(o.opts, states); init = ((), dx)) do (states′, dx′), (opt, state) if dx′ isa Zero return (states′..., state), dx′ - else + else state′, dx′ = apply!(opt, state, x, dx′, dxs...) 
       return (states′..., state′), dx′
     end
@@ -831,10 +907,10 @@ julia> m # n=2 gradients applied at once
 """
 struct AccumGrad <: AbstractRule
   n::Int
-  
+
   function AccumGrad(n::Int)
     n > 0 || throw(ArgumentError("AccumGrad must accumulate at least one gradient"))
-    return new(n) 
+    return new(n)
   end
 end
 
diff --git a/test/rules.jl b/test/rules.jl
index 499902ca..2cbc3efc 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -8,7 +8,7 @@ RULES = [
   # All the rules at default settings:
   Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
-  AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(), 
+  AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(), Muon(),
   # A few chained combinations:
   OptimiserChain(SignDecay(0.001), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -183,7 +183,7 @@ end
     # The Flux PR had 1e-2 for all. But AdaDelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
     Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2),
     # These weren't in Flux PR:
-    Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), AdamW(1e-2), 
+    Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), AdamW(1e-2),
     ]
     # Our "model" is just a complex number
     model = (w = zeros(ComplexF64, 1),)
@@ -226,7 +226,7 @@ end
       @test static_loss(static_model) < last_loss
       last_loss = static_loss(static_model)
     end
-    @test static_loss(static_model) < 1.9 
+    @test static_loss(static_model) < 1.9
   end
 end
 
@@ -254,16 +254,16 @@ end
   g1 = rand(5)
   tree, x1 = Optimisers.update(tree, x, g1)
   @test x1 ≈ x
-  @test x1 ≈ x0 
+  @test x1 ≈ x0
   g2 = rand(5)
   tree, x2 = Optimisers.update(tree, x1, g2)
   @test x2 ≈ x
-  @test x2 ≈ x0 
+  @test x2 ≈ x0
   g3 = rand(5)
   tree, x3 = Optimisers.update(tree, x2, g3)
   @test x3 ≈ x0 .- lr .* (g1 .+ g2 .+ g3) ./ 3
   g4 = rand(5)
-  
+
   tree, x4 = Optimisers.update(tree, x3, g4)
   @test x4 ≈ x3
 end