From e5b85bcff14c7c966318b2743589d24589f027a0 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 21 Dec 2024 14:42:19 -0500
Subject: [PATCH] add setup(::Function, model)

---
 src/Optimisers.jl | 17 ++++++++++++++---
 src/interface.jl  | 10 ++++++----
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 99fc162f..289ae850 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -78,17 +78,28 @@ init
 
 """
     Optimisers.setup(rule, model) -> state_tree
+    Optimisers.setup(f::Function, model) -> state_tree
 
 Initialises the given optimiser for every trainable parameter within the model.
 Returns a tree of the relevant states, which must be passed to
 [`update`](@ref Optimisers.update) or [`update!`](@ref Optimisers.update!).
 
 # Example
-```jldoctest
-julia> m = (x = rand(3), y = (true, false), z = tanh);
+```jldoctest setup1
+julia> m = (x = rand(3), y = (true, false), z = randn(2, 2));
 
 julia> Optimisers.setup(Momentum(), m)  # same field names as m
-(x = Leaf(Momentum(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
+(x = Leaf(Momentum(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = Leaf(Momentum(0.01, 0.9), [0.0 0.0; 0.0 0.0]))
+```
+
+The method accepting a function `f(x::AbstractArray)::AbstractRule` lets you apply different
+optimisation rules to different trainable arrays, chosen by `size`, `ndims`, or other properties:
+
+```jldoctest setup1
+julia> Optimisers.setup(m) do a
+         ndims(a) == 1 ? Descent() : Adam()
+       end
+(x = Leaf(Descent(0.1), nothing), y = ((), ()), z = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], (0.9, 0.999))))
 ```
 
 The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
diff --git a/src/interface.jl b/src/interface.jl
index e44dec15..08095ccf 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -26,17 +26,19 @@ Leaf(rule, state; frozen::Bool = false) = Leaf(rule, state, frozen)
 
 Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)
 
-function setup(rule::AbstractRule, model)
+setup(rule::AbstractRule, model) = setup(Returns(rule), model)
+function setup(fun::Function, model)
   cache = IdDict()
-  tree = _setup(rule, model; cache)
+  tree = _setup(fun, model; cache)
   isempty(cache) && @warn "setup found no trainable parameters in this model"
   tree
 end
 
 # _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
-function _setup(rule, x; cache)
+function _setup(fun::Function, x; cache)
   haskey(cache, x) && return cache[x]
   if isnumeric(x)
+    rule = fun(x)::AbstractRule
     ℓ = Leaf(rule, init(rule, x))
     if isbits(x)
       cache[nothing] = nothing  # just to disable the warning
@@ -45,7 +47,7 @@
       cache[x] = ℓ
     end
   else
-    mapvalue(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
+    mapvalue(xᵢ -> _setup(fun, xᵢ; cache), _trainable(x))
   end
 end
 
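
For reviewers, a quick end-to-end sketch of the new method together with the package's
existing `Optimisers.update!`; the two-array `model` and the `grad` values below are made
up for illustration, not part of the patch:

```julia
using Optimisers

# A hypothetical model with one matrix and one vector of trainable parameters.
model = (weight = randn(2, 3), bias = zeros(2))

# New in this patch: pass a function that picks a rule per trainable array.
state = Optimisers.setup(model) do x
  ndims(x) == 1 ? Descent(0.1) : Adam(0.01)
end

# A made-up gradient with the same structure as the model.
grad = (weight = ones(2, 3), bias = fill(0.5, 2))

# The resulting state tree feeds update! exactly as in the one-rule case.
state, model = Optimisers.update!(state, model, grad)
```

Since `setup(rule::AbstractRule, model)` now just forwards through `Returns(rule)`, both
methods share the same `_setup` recursion, cache, and no-trainable-parameters warning.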