
Conversation

@mcabbott
Member

Quick sketch of one way to easily allow different rules for different arrays, by modifying setup -- see docstring.

PR Checklist

  • Tests are added
  • Documentation, if applicable

@AntonOresten

I think this is elegant and useful. I'm working on some improvements to #203. Muon is optimal for linear layers, but makes less sense for e.g. Flux.Embedding, even though it is linear-like; and since the linear decoder layer in LLMs is often tied to the input encoder layer, it's preferable to disable Muon for that layer as well. I imagine the cleanest way of differentiating between linear layers is with an IdDict inside the setup rule function: you'd, for example, create the function based on which layers are present in some IdDict, and in the same function embed rules for different array shapes.

@AntonOresten

AntonOresten commented Oct 22, 2025

One could do something like:

function fun_rule(model, rule=Muon(), fallback=Adam())
    # Identity-based set of arrays that should not use `rule`,
    # here the (possibly tied) encoder and decoder weights:
    skipped = Base.IdSet{Any}([model.encode.weight, model.decode.weight])
    fun(x::AbstractVector) = fallback                       # vectors (biases etc.) always get the fallback
    fun(x::AbstractArray) = x in skipped ? fallback : rule  # skipped matrices get the fallback too
    return fun
end

opt_state = Optimisers.setup(fun_rule(model), model)

such that:

julia> model = (;
           encode=(; weight=rand(2,2)),
           other=(; weight=rand(2,2), bias=rand(2)),
           decode=(; weight=rand(2,2)));

julia> fun_rule(model)(model.encode.weight)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

julia> fun_rule(model)(model.other.weight)
Muon(0.02, 0.95, 0.01, 1.0e-7, true)

julia> fun_rule(model)(model.other.bias)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

julia> fun_rule(model)(model.decode.weight)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

I generally avoid closures, but this has a certain elegance to it. Base.IdSet is private, but the alternative is slightly cursed:

skipped = keys(IdDict([model.encode.weight, model.decode.weight] .=> nothing))
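A slightly less cursed public-API variant of the same idea might be to keep the `IdDict` itself and test membership with `haskey`, which compares by identity just like an `IdSet` would. A minimal sketch (using `:muon` / `:adam` placeholder values in place of the actual `Muon()` / `Adam()` rules, so it runs without Optimisers):

```julia
# Sketch: an IdDict whose keys act as an identity-based skip set,
# queried with haskey instead of the private Base.IdSet.
function fun_rule2(model; rule=:muon, fallback=:adam)
    skipped = IdDict{Any,Nothing}()
    skipped[model.encode.weight] = nothing  # tied encoder weight
    skipped[model.decode.weight] = nothing  # tied decoder weight
    fun(x::AbstractVector) = fallback
    fun(x::AbstractArray) = haskey(skipped, x) ? fallback : rule
    return fun
end
```

Since `IdDict` hashes by `objectid`, two distinct arrays with equal contents are still kept apart, which is exactly what's wanted for tied weights.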

