Commit 9928588

vpuri3, ToucheSir, and mcabbott authored
Add Adapt.adapt_structure method for Optimisers.Leaf (#180)
* Adapt.adapt_structure method for Optimisers.Leaf
* import Adapt.jl
* add Adapt.jl to Project.toml
* adapt compat
* based on discussion: adapt_structure method does not maintain IdDict handled by functors. So we add a warning referring the user to Flux.gpu or MLDataDevices.gpu_device()
* Update ext/OptimisersAdaptExt.jl
  Co-authored-by: Brian Chen <[email protected]>
* edit warning to indicate that this is a correctness issue
* Update ext/OptimisersAdaptExt.jl
  Co-authored-by: Michael Abbott <[email protected]>

---------

Co-authored-by: Brian Chen <[email protected]>
Co-authored-by: Michael Abbott <[email protected]>
1 parent d842ddb commit 9928588

2 files changed: +23 -0 lines changed

Project.toml

Lines changed: 3 additions & 0 deletions

@@ -11,12 +11,15 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [weakdeps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 
 [extensions]
+OptimisersAdaptExt = ["Adapt"]
 OptimisersEnzymeCoreExt = "EnzymeCore"
 
 [compat]
+Adapt = "4"
 ChainRulesCore = "1"
 EnzymeCore = "0.8.5"
 Functors = "0.4.9, 0.5"

ext/OptimisersAdaptExt.jl

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+module OptimisersAdaptExt
+
+import Adapt
+import Optimisers: Leaf
+
+function Adapt.adapt_structure(to, leaf::Leaf)
+    @warn """`Optimisers.Leaf` object does not support device transfer via
+    `Adapt.jl`. This is because `Adapt.jl` does not handle shared parameters (i.e. the same parameter array
+    appearing more than once in the model), and in such cases this will lead to incorrect gradient updates.
+    Avoid this by calling `Flux.gpu/cpu` or `MLDataDevices.cpu_device()/gpu_device()` on the
+    optimiser state object.
+    """ maxlog=1
+
+    rule = Adapt.adapt(to, leaf.rule)
+    state = Adapt.adapt(to, leaf.state)
+
+    Leaf(rule, state, leaf.frozen)
+end
+
+end
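
The correctness issue named in the warning can be illustrated with a small sketch. It is not part of this commit; the toy `model` and the names `x`, `state`, `moved`, `tied_before`, and `tied_after` are assumptions made up for illustration. It assumes an Optimisers.jl version where `Optimisers.setup` gives a shared parameter array a single shared `Leaf`, and uses `Array` as a CPU-only stand-in for a device adaptor.

```julia
using Optimisers
import Adapt

# Hypothetical toy model: the same array appears under two fields (a shared parameter).
x = [1.0, 2.0]
model = (a = x, b = x)

# setup is expected to give both fields the very same Leaf object.
state = Optimisers.setup(Optimisers.Adam(), model)
tied_before = state.a === state.b

# Adapt.jl walks the NamedTuple field by field. With the extension from this
# commit loaded, each field gets a freshly constructed Leaf, so the tie is
# expected to be lost (and the warning above is logged once).
moved = Adapt.adapt(Array, state)
tied_after = moved.a === moved.b

println("tied before adapt: ", tied_before)
println("tied after adapt:  ", tied_after)
```

If `tied_after` prints `false`, the two fields now carry independent optimiser state, which is the situation the warning describes; transferring the state with `Flux.gpu/cpu` or an `MLDataDevices` device, as the warning recommends, avoids it.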
