 """
-    Optimisers.setup(rule, model) -> tree
+    Optimisers.setup(rule, model) -> state_tree

 Initialises the given optimiser for every trainable parameter within the model.

 Returns a tree of the relevant states, which must be passed to [`update`](@ref)
@@ -142,6 +142,7 @@ This is used in exactly the same manner as [`update`](@ref), but because it may
 arrays within the old model (and the old state), it will be faster for models of ordinary
 `Array`s or `CuArray`s. However, you should not rely on the old model being fully updated
 but rather use the returned model.
+(The original state tree is always mutated, as each `Leaf` is mutable.)

 # Example
@@ -150,9 +151,10 @@ julia> using StaticArrays, Zygote, Optimisers

 julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]);  # partly mutable model

-julia> t = Optimisers.setup(Momentum(1/30, 0.9), m);
+julia> t = Optimisers.setup(Momentum(1/30, 0.9), m)  # tree of states
+(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))

-julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]
+julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]  # structural gradient
 (x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])

 julia> t2, m2 = Optimisers.update!(t, m, g);

 julia> m  # original should be discarded, may be mutated but no guarantee
 (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])

-julia> t == t2  # original state is in fact guaranteed to be mutated
+julia> t == t2  # original state tree is guaranteed to be mutated
 true
 ```
 """