|
| 1 | +module OptimisersEnzymeCoreExt |
| 2 | + |
| 3 | +import Optimisers: trainable, setup, update!, isnumeric, AbstractRule, _setup |
| 4 | +import EnzymeCore: Duplicated, Const |
| 5 | + |
| 6 | +using Functors: fmapstructure |
| 7 | + |
# Only the primal half of a `Duplicated` pair is trainable (the shadow `dval`
# holds the gradient); `Const`-wrapped objects expose no trainable fields.
trainable(dup::Duplicated) = (val = dup.val,)
trainable(::Const) = NamedTuple()
| 10 | + |
"""
    setup(rule::AbstractRule, model_grad::Duplicated)

For use with Enzyme's `Duplicated`, this just calls `setup(rule, model_grad.val)`.
"""
function setup(rule::AbstractRule, model_grad::Duplicated)
    # The optimiser state tracks only the primal model, never the shadow gradient.
    return setup(rule, model_grad.val)
end
| 17 | + |
# `Duplicated` is unwrapped at the top level before the recursive model walk,
# so encountering one deeper in the tree during `_setup` is a user error.
function _setup(rule, x::Duplicated; cache)
    throw(ArgumentError(
        """Objects of type `Duplicated` are only supported by Optimisers.jl at top level,
        they may not appear deep inside other objects."""
    ))
end
| 22 | + |
"""
    update!(opt_state, model_grad::Duplicated)

For use with Enzyme's `Duplicated`, which holds both a model/parameters
and the corresponding gradient.

# Example

```jldoctest
julia> using Optimisers, EnzymeCore

julia> x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
Duplicated{Vector{Float16}}(Float16[1.0, 2.0, 3.0], Float16[1.0, 0.0, -4.0])

julia> st = Optimisers.setup(Momentum(1/9), x_dx) # acts only on x not on dx
Leaf(Momentum(0.111111, 0.9), Float16[0.0, 0.0, 0.0])

julia> Optimisers.update!(st, x_dx) # mutates both arguments

julia> x_dx
Duplicated{Vector{Float16}}(Float16[0.8887, 2.0, 3.445], Float16[1.0, 0.0, -4.0])

julia> st
Leaf(Momentum(0.111111, 0.9), Float16[0.1111, 0.0, -0.4443])
```
"""
function update!(opt_state, model_grad::Duplicated)
    # Mutating three-argument `update!` does the real work; its returned
    # (state, model) pair is deliberately discarded since both arguments
    # are updated in place.
    update!(opt_state, model_grad.val, _grad_or_nothing(model_grad))
    return nothing
end
| 53 | + |
# Turn the stored shadow (`dval`) into a Zygote-like gradient tree.
# Crucially, `prune=nothing` makes `fmapstructure` replace the second and
# later appearances of a shared gradient with `nothing`, avoiding
# double-counting of shared parameters.
_grad_or_nothing(dup::Duplicated) = fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
_grad_or_nothing(::Const) = nothing
function _grad_or_nothing(x)
    # Leaves: keep numeric arrays, drop everything else.
    if isnumeric(x)
        return x
    else
        return nothing
    end
end
| 59 | + |
| 60 | +end |
0 commit comments