# Optimisers.jl

## Defining an optimisation rule

A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref).
These act on one array of parameters:

```julia
# Define a container to hold any optimiser specific parameters (if any):
struct DecayDescent <: Optimisers.AbstractRule
  eta::Float64
end

# Define an `apply!` rule which encodes how the gradients will be used to
# update the parameters:
function Optimisers.apply!(o::DecayDescent, state, x, x̄)
  T = eltype(x)                  # keep the step size in the same element type as x
  newx̄ = T(o.eta / √state) .* x̄  # scale the gradient by a decaying step size
  nextstate = state + 1          # count the number of steps taken
  return nextstate, newx̄
end

# Define a function to set up the initial state (if any):
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1
```

The parameters will be immediately updated to `x .- newx̄`, while `nextstate` is
carried to the next iteration.

Notice that the state is handled separately from the optimiser itself. This
is a key design principle and allows users to manage their own state explicitly.
It of course also makes it easier to store the state.
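
For a single array, these two functions can be called by hand; a minimal sketch, assuming the
`DecayDescent` rule above and a gradient `x̄` computed elsewhere:

```julia
rule = DecayDescent(0.1)
x = rand(3)   # one array of parameters
x̄ = ones(3)   # a gradient for it, computed elsewhere

state = Optimisers.init(rule, x)                    # initial state, here just 1
state, newx̄ = Optimisers.apply!(rule, state, x, x̄)  # new state, and the re-scaled gradient
x = x .- newx̄                                       # the new parameters which `update` would return
```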

## Usage with [Flux.jl](https://github.com/FluxML/Flux.jl)

To apply such an optimiser to a whole model, [`setup`](@ref) builds a tree containing any initial
state for every trainable array. Then at each step, [`update`](@ref) uses this and the gradient
to adjust the model:

```julia
using Flux, Metalhead, Optimisers, Zygote

model = Metalhead.ResNet(18) |> gpu  # define a model to train
image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
@show sum(model(image));  # dummy loss function

rule = Optimisers.Adam()  # use the Adam optimiser with its default settings
state = Optimisers.setup(rule, model);  # initialise this optimiser's momentum etc.

∇model, _ = gradient(model, image) do m, x  # calculate the gradients
  sum(m(x))
end;

state, model = Optimisers.update(state, model, ∇model);
@show sum(model(image));
```

This `∇model` is another tree structure, rather than the dictionary-like object from
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` you write is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
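
Continuing the example above, the in-place variant has the same calling convention,
and its return values should still be used:

```julia
# Arrays inside `model` and `state` may be overwritten, rather than copied:
state, model = Optimisers.update!(state, model, ∇model);
```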

## Usage with [Lux.jl](https://github.com/LuxDL/Lux.jl)
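
The same rules work with [Lux.jl](https://github.com/LuxDL/Lux.jl), whose models keep the
parameter tree separate from the layer definitions. Below is a minimal sketch using a small
stand-in model; the essential pieces are `Lux.setup`, `Lux.apply`, and the same
`Optimisers.setup` / `Optimisers.update` calls:

```julia
using Lux, Optimisers, Random, Zygote

lux_model = Lux.Chain(Lux.Dense(3 => 5, tanh), Lux.Dense(5 => 1))  # small stand-in model
params, lux_state = Lux.setup(Random.default_rng(), lux_model);    # parameters live outside the model
images = rand(Float32, 3, 4);                                      # dummy input batch

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params);  # optimiser state mirrors the parameter tree

∇params, _ = Zygote.gradient(params, images) do p, x
  y, _ = Lux.apply(lux_model, x, p, lux_state)  # run the model with explicit parameters
  sum(y)  # dummy loss
end;

opt_state, params = Optimisers.update(opt_state, params, ∇params);
```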
Besides the parameters stored in `params` and gradually optimised, any other model state
is stored in `lux_state`. For simplicity this example does not show how to propagate the
updated `lux_state` to the next iteration, see Lux's documentation.

## Obtaining a flat parameter vector

Instead of a nested tree-like structure, sometimes it is convenient to have all the
parameters as one simple vector. Optimisers.jl contains a function [`destructure`](@ref)
which creates this vector, and also creates a way to re-build the original structure
with new parameters. Both flattening and re-building may be used within `gradient` calls.

An example with Flux's `model`:

```julia
using ForwardDiff  # an example of a package which only likes one array

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
          Conv((3, 3), 3 => 5, pad=1, bias=false),
          BatchNorm(5, relu),
          Conv((3, 3), 5 => 3, stride=16),
        )
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

flat, re = destructure(model)
st = Optimisers.setup(rule, flat)  # state is just one Leaf now

∇flat = ForwardDiff.gradient(flat) do v
  m = re(v)      # rebuild a new object like model
  sum(m(image))  # call that as before
end

st, flat = Optimisers.update(st, flat, ∇flat)
@show sum(re(flat)(image));
```
Here `flat` contains only the 283 trainable parameters, while the non-trainable
ones are preserved inside `re`.
When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref).
By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl)
are assumed to contain trainable parameters.
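
For example, a hypothetical layer whose `shift` field should stay fixed might be declared
like this (a sketch; `ScaleOnly` is not part of any package):

```julia
using Functors, Optimisers

struct ScaleOnly  # hypothetical layer: `scale` is trained, `shift` stays fixed
  scale::Vector{Float32}
  shift::Vector{Float32}
end
Functors.@functor ScaleOnly                               # make both fields visible to Functors.jl
Optimisers.trainable(s::ScaleOnly) = (; scale = s.scale)  # but report only `scale` as trainable
```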
Lux stores only the trainable parameters in `params`.
This can also be flattened to a plain `Vector` in the same way:

```julia
params, lux_state = Lux.setup(Random.default_rng(), lux_model);

flat, re = destructure(params)

∇flat = ForwardDiff.gradient(flat) do v
  p = re(v)  # rebuild an object like params
  y, _ = Lux.apply(lux_model, images, p, lux_state)
  sum(y)
end
```
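
The flat vector can then be fed to the same update step as before, and re-built into a
parameter tree for Lux; a sketch re-using `rule` from above:

```julia
st = Optimisers.setup(rule, flat)              # the state is again a single Leaf
st, flat = Optimisers.update(st, flat, ∇flat)
params = re(flat)                              # re-build the parameter tree for Lux.apply
```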