
Commit 9a87a88

Merge pull request #136 from mcabbott/doc_april
Don't use `state` anywhere for the whole state tree
2 parents 54c3330 + f3c014c commit 9a87a88

2 files changed: +9 −7 lines changed


docs/src/index.md

Lines changed: 3 additions & 3 deletions
@@ -45,13 +45,13 @@ image = rand(Float32, 224, 224, 3, 1) |> gpu; # dummy data
 @show sum(model(image)); # dummy loss function

 rule = Optimisers.Adam() # use the Adam optimiser with its default settings
-state = Optimisers.setup(rule, model); # initialise this optimiser's momentum etc.
+state_tree = Optimisers.setup(rule, model); # initialise this optimiser's momentum etc.

 ∇model, _ = gradient(model, image) do m, x # calculate the gradients
   sum(m(x))
 end;

-state, model = Optimisers.update(state, model, ∇model);
+state_tree, model = Optimisers.update(state_tree, model, ∇model);
 @show sum(model(image)); # reduced

 ```
@@ -60,7 +60,7 @@ Notice that a completely new instance of the model is returned. Internally, this
 is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
 tree formed by the model and update the parameters using the gradients.

-There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
+There is also [`Optimisers.update!`](@ref) which similarly returns a new model,
 but is free to mutate arrays within the old one for efficiency.
 (The method of `apply!` above is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.)
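
The renamed quick-start keeps `Optimisers.update`, which returns a fresh model. For the in-place variant mentioned in the prose above, a minimal sketch (reusing the `rule`, `model`, and `image` from the quick-start; not part of this diff) would look like:

    state_tree = Optimisers.setup(rule, model)             # tree of optimiser states
    ∇model, _ = gradient((m, x) -> sum(m(x)), model, image)
    state_tree, model = Optimisers.update!(state_tree, model, ∇model)  # may reuse old arrays
    # carry on with the returned model and state_tree, not the originals
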

src/Optimisers.jl

Lines changed: 6 additions & 4 deletions
@@ -67,7 +67,7 @@ init
 ###

 """
-    Optimisers.setup(rule, model) -> tree
+    Optimisers.setup(rule, model) -> state_tree

 Initialises the given optimiser for every trainable parameter within the model.
 Returns a tree of the relevant states, which must be passed to [`update`](@ref)
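
Only the documented return-value name changes here. As a reminder of what that tree looks like, a small sketch (the nested NamedTuple model and `Descent` rule are just illustrative; the exact printed form of `Leaf` may vary between versions):

    using Optimisers
    nested = (layer = (weight = [1.0 2.0; 3.0 4.0], bias = [0.0, 0.0]), scale = [1.0])
    state_tree = Optimisers.setup(Optimisers.Descent(0.1), nested)
    # state_tree mirrors the model: (layer = (weight = Leaf(...), bias = Leaf(...)), scale = Leaf(...))
    # and is what gets threaded through update / update! on every step
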
@@ -142,6 +142,7 @@ This is used in exactly the same manner as [`update`](@ref), but because it may
 arrays within the old model (and the old state), it will be faster for models of ordinary
 `Array`s or `CuArray`s. However, you should not rely on the old model being fully updated
 but rather use the returned model.
+(The original state tree is always mutated, as each `Leaf` is mutable.)

 # Example

@@ -150,9 +151,10 @@ julia> using StaticArrays, Zygote, Optimisers

 julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model

-julia> t = Optimisers.setup(Momentum(1/30, 0.9), m);
+julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
+(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))

-julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]
+julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
 (x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])

 julia> t2, m2 = Optimisers.update!(t, m, g);
@@ -166,7 +168,7 @@ true
 julia> m # original should be discarded, may be mutated but no guarantee
 (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])

-julia> t == t2 # original state is in fact guaranteed to be mutated
+julia> t == t2 # original state tree is guaranteed to be mutated
 true
 ```
 """
