@@ -66,7 +66,7 @@
 ###
 
 """
-    Optimisers.setup(rule, model) -> tree
+    Optimisers.setup(rule, model) -> state_tree
 
 Initialises the given optimiser for every trainable parameter within the model.
 Returns a tree of the relevant states, which must be passed to [`update`](@ref)
@@ -141,6 +141,7 @@ This is used in exactly the same manner as [`update`](@ref), but because it may
 arrays within the old model (and the old state), it will be faster for models of ordinary
 `Array`s or `CuArray`s. However, you should not rely on the old model being fully updated
 but rather use the returned model.
+(The original state tree is always mutated, as each `Leaf` is mutable.)
 
 # Example
 
@@ -149,9 +150,10 @@ julia> using StaticArrays, Zygote, Optimisers
 
 julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model
 
-julia> t = Optimisers.setup(Momentum(1/30, 0.9), m);
+julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
+(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
 
-julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]
+julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
 (x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
 
 julia> t2, m2 = Optimisers.update!(t, m, g);
@@ -165,7 +167,7 @@
 julia> m # original should be discarded, may be mutated but no guarantee
 (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])
 
-julia> t == t2 # original state tree is guaranteed to be mutated
+julia> t == t2 # original state tree is guaranteed to be mutated
 true
 ```
 """