
Commit f3c014c: "don't use state for tree"
1 parent 14949f1

2 files changed: +9 −7 lines


docs/src/index.md
Lines changed: 3 additions & 3 deletions
@@ -45,13 +45,13 @@ image = rand(Float32, 224, 224, 3, 1) |> gpu; # dummy data
 @show sum(model(image)); # dummy loss function
 
 rule = Optimisers.Adam() # use the Adam optimiser with its default settings
-state = Optimisers.setup(rule, model); # initialise this optimiser's momentum etc.
+state_tree = Optimisers.setup(rule, model); # initialise this optimiser's momentum etc.
 
 ∇model, _ = gradient(model, image) do m, x # calculate the gradients
   sum(m(x))
 end;
 
-state, model = Optimisers.update(state, model, ∇model);
+state_tree, model = Optimisers.update(state_tree, model, ∇model);
 @show sum(model(image)); # reduced
 
 ```
@@ -60,7 +60,7 @@ Notice that a completely new instance of the model is returned. Internally, this
 is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
 tree formed by the model and update the parameters using the gradients.
 
-There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
+There is also [`Optimisers.update!`](@ref) which similarly returns a new model,
 but is free to mutate arrays within the old one for efficiency.
 (The method of `apply!` above is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.)
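The renamed workflow from this docs change can be exercised without Flux or a GPU. A minimal sketch, assuming Optimisers.jl and Zygote are installed, and using a toy NamedTuple in place of the doc's image model (the names `model`, `grad` are illustrative):

```julia
using Optimisers, Zygote

model = (w = [1.0, 2.0],)       # any nested structure of arrays works as a "model"
rule = Optimisers.Adam()        # Adam with its default settings
state_tree = Optimisers.setup(rule, model)  # one Leaf of optimiser state per trainable array

grad = gradient(m -> sum(abs2, m.w), model)[1]  # structural gradient, mirrors the model

state_tree, model = Optimisers.update(state_tree, model, grad)
# `update` returns a new state tree and a new model; the originals should be discarded
```

As the docs note, `update` is non-mutating (it defensively copies state), while `update!` may reuse the old arrays for speed.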

src/Optimisers.jl
Lines changed: 6 additions & 4 deletions
@@ -66,7 +66,7 @@ init
 ###
 
 """
-    Optimisers.setup(rule, model) -> tree
+    Optimisers.setup(rule, model) -> state_tree
 
 Initialises the given optimiser for every trainable parameter within the model.
 Returns a tree of the relevant states, which must be passed to [`update`](@ref)
@@ -141,6 +141,7 @@ This is used in exactly the same manner as [`update`](@ref), but because it may
 arrays within the old model (and the old state), it will be faster for models of ordinary
 `Array`s or `CuArray`s. However, you should not rely on the old model being fully updated
 but rather use the returned model.
+(The original state tree is always mutated, as each `Leaf` is mutable.)
 
 # Example
 
@@ -149,9 +150,10 @@ julia> using StaticArrays, Zygote, Optimisers
 
 julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model
 
-julia> t = Optimisers.setup(Momentum(1/30, 0.9), m);
+julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
+(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
 
-julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]
+julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
 (x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
 
 julia> t2, m2 = Optimisers.update!(t, m, g);
@@ -165,7 +167,7 @@ true
 julia> m # original should be discarded, may be mutated but no guarantee
 (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])
 
-julia> t == t2 # original state is in fact guaranteed to be mutated
+julia> t == t2 # original state tree is guaranteed to be mutated
 true
 ```
 """
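The line this commit adds to the `update!` docstring (each `Leaf` is mutable, so the old state tree is updated in place) can be checked directly. A sketch assuming Optimisers.jl and Zygote are installed, using a fully mutable model so no StaticArrays are needed:

```julia
using Optimisers, Zygote

m = (x = [1f0, 2f0],)                  # fully mutable model
t = Optimisers.setup(Optimisers.Momentum(1/30, 0.9), m)

g = gradient(m -> sum(abs2, m.x), m)[1]  # structural gradient: (x = Float32[2.0, 4.0],)
t2, m2 = Optimisers.update!(t, m, g)

# The returned tree compares equal to the original: update! mutated each Leaf in place,
# so `t` now holds the post-step momentum, exactly as the docstring change states.
```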
