
Commit 2639523

Add Duplicated methods (#192)
* add Duplicated methods
* add test
* test for shared params + minimal docs
* remove 1.6 CI
* indent by two spaces
* fix doctest
1 parent 38c9d62 commit 2639523

5 files changed, 104 additions and 5 deletions

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -15,9 +15,9 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.6'
           - '1'
           - 'nightly'
+          - "1.10"
         os:
           - ubuntu-latest
         arch:

Project.toml

Lines changed: 11 additions & 3 deletions
@@ -1,26 +1,34 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
+version = "0.4.1"
 authors = ["Mike J Innes <[email protected]>"]
-version = "0.4.0"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
+[weakdeps]
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
+
+[extensions]
+OptimisersEnzymeCoreExt = "EnzymeCore"
+
 [compat]
 ChainRulesCore = "1"
+EnzymeCore = "0.8.5"
 Functors = "0.4.9, 0.5"
 Statistics = "1"
 Zygote = "0.6.40"
-julia = "1.6"
+julia = "1.10"
 
 [extras]
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "StaticArrays", "Zygote"]
+test = ["Test", "EnzymeCore", "StaticArrays", "Zygote"]

docs/src/index.md

Lines changed: 9 additions & 0 deletions
@@ -358,3 +358,12 @@ julia> Optimisers.update!(opt_state, x, g);
 julia> opt_state # the state in `a` and `b` differ
 (a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.09, 0.09], [0.000999, 0.000999], (0.729, 0.997003))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
 ```
+
+## Usage with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)
+
+Enzyme.jl is a new automatic differentiation package, an alternative to Zygote.jl.
+It likes to store the model and the gradient together, as an object `Duplicated(x, dx)`.
+
+Optimisers.jl now has some methods to handle this:
+* `update!(opt_state, Duplicated(model, grad))` uses the gradient to update both the model and the optimiser state, and
+* `setup(::AbstractRule, ::Duplicated)` ignores the gradient and returns `setup(rule, model)`.
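
To make those two bullet points concrete, here is a small sketch (not part of the committed docs) driving the new methods with a hand-written gradient standing in for one produced by Enzyme.jl; the toy model and numbers are made up for illustration:

```julia
using Optimisers, EnzymeCore

# A toy "model" and a gradient of the same shape; with Enzyme.jl the gradient
# half would be filled in by automatic differentiation instead.
model = (w = [1.0, 2.0], b = [0.5])
grad  = (w = [0.1, -0.2], b = [0.3])
dup   = Duplicated(model, grad)              # model and gradient stored together

state = Optimisers.setup(Descent(0.1), dup)  # same as setup(Descent(0.1), model)
Optimisers.update!(state, dup)               # 2-arg method: reads dup.dval, mutates dup.val in place

model.w  # ≈ [0.99, 2.02] after one Descent step with learning rate 0.1
```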

ext/OptimisersEnzymeCoreExt.jl

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+module OptimisersEnzymeCoreExt
+
+import Optimisers: trainable, setup, update!, isnumeric, AbstractRule, _setup
+import EnzymeCore: Duplicated, Const
+
+using Functors: fmapstructure
+
+trainable(x::Duplicated) = (; val = x.val)
+trainable(x::Const) = (;)
+
+"""
+    setup(rule::AbstractRule, model_grad::Duplicated)
+
+For use with Enzyme's Duplicated, this just calls `setup(rule, model_grad.val)`.
+"""
+setup(rule::AbstractRule, model_grad::Duplicated) = setup(rule, model_grad.val)
+
+_setup(rule, x::Duplicated; cache) = throw(ArgumentError(
+  """Objects of type `Duplicated` are only supported by Optimisers.jl at top level,
+  they may not appear deep inside other objects."""
+))
+
+"""
+    update!(opt_state, model_grad::Duplicated)
+
+For use with Enzyme's `Duplicated`, which holds both a model/parameters
+and the corresponding gradient.
+
+# Example
+
+```jldoctest
+julia> using Optimisers, EnzymeCore
+
+julia> x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
+Duplicated{Vector{Float16}}(Float16[1.0, 2.0, 3.0], Float16[1.0, 0.0, -4.0])
+
+julia> st = Optimisers.setup(Momentum(1/9), x_dx) # acts only on x not on dx
+Leaf(Momentum(0.111111, 0.9), Float16[0.0, 0.0, 0.0])
+
+julia> Optimisers.update!(st, x_dx) # mutates both arguments
+
+julia> x_dx
+Duplicated{Vector{Float16}}(Float16[0.8887, 2.0, 3.445], Float16[1.0, 0.0, -4.0])
+
+julia> st
+Leaf(Momentum(0.111111, 0.9), Float16[0.1111, 0.0, -0.4443])
+```
+"""
+function update!(opt_state, model_grad::Duplicated)
+  _, _ = update!(opt_state, model_grad.val, _grad_or_nothing(model_grad))
+  nothing
+end
+
+# This function strips the returned gradient to be Zygote-like,
+# most importantly prune=nothing removes 2nd appearance of shared gradient to avoid double-counting.
+_grad_or_nothing(dup::Duplicated) = fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
+_grad_or_nothing(::Const) = nothing
+_grad_or_nothing(x) = isnumeric(x) ? x : nothing
+
+end
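
The `prune=nothing` comment above is the subtle part: when the same parameter array appears in several places, Enzyme has already accumulated its gradient once, so the second structural appearance must not be applied again. A hedged sketch of what `Functors.fmapstructure` does here, using a hand-built gradient with an aliased array (not an actual Enzyme output):

```julia
using Functors

shared = [1.0, 2.0]
grads  = (x = shared, y = shared)   # the same array reached via two fields

# With prune = nothing, the second time the walk reaches `shared` it returns
# `nothing` instead of the array, so a downstream update! sees the gradient only once.
Functors.fmapstructure(identity, grads; prune = nothing)
# (x = [1.0, 2.0], y = nothing)    -- roughly what _grad_or_nothing produces
```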

test/runtests.jl

Lines changed: 23 additions & 1 deletion
@@ -1,5 +1,5 @@
 using Optimisers
-using ChainRulesCore, Functors, StaticArrays, Zygote
+using ChainRulesCore, Functors, StaticArrays, Zygote, EnzymeCore
 using LinearAlgebra, Statistics, Test, Random
 using Optimisers: @.., @lazy
 using Base.Broadcast: broadcasted, instantiate, Broadcasted
@@ -534,6 +534,28 @@ end
     @test Optimisers._norm(bc2, p) isa Float64
   end
 end
+
+@testset "Enzyme Duplicated" begin
+  x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
+  st = Optimisers.setup(Momentum(1/9), x_dx) # acts only on x not on dx
+  @test st isa Optimisers.Leaf
+  @test nothing === Optimisers.update!(st, x_dx) # mutates both arguments
+  @test x_dx.val ≈ Float16[0.8887, 2.0, 3.445]
+
+  shared = [1.0]
+  model = (x=shared, y=shared)
+  grad = deepcopy(model) # Enzyme produces something like this, grad.x === grad.y, already accumulated.
+  dup = Duplicated(model, model)
+  st2 = Optimisers.setup(Descent(0.1), model)
+  Optimisers.update!(st2, dup)
+  @test model.x ≈ [0.9]
+  shared .= 1
+  Optimisers.update!(st2, model, grad)
+  model.x ≈ [0.8] # This is wrong, but don't make it a test.
+  # Ideally, perhaps the 3-arg update! could notice that grad.x===grad.y, and not accumulate the gradient in this case?
+
+  @test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx)) # Duplicated deep inside is not allowed
+end
 end
 @testset verbose=true "Destructure" begin
   include("destructure.jl")
