# Optimisers.jl

- ## Define an Optimiser
+ ## Defining an Optimiser

A new optimiser must overload two functions, `apply!` and `init`:

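For instance, a rule whose step size decays as 1/√t might look roughly like the sketch
below. This is only an illustration: it assumes that rule types subtype
`Optimisers.AbstractRule`, and that `apply!(rule, state, x, dx)` should return the new
state together with the value which `update` will subtract from the parameter `x`;
these conventions follow my reading of the package and should be checked against its full documentation.

```julia
# Hypothetical rule, for illustration only: step size decays as eta / sqrt(t).
struct DecayDescent{T} <: Optimisers.AbstractRule
  eta::T   # initial learning rate
end

# Initial per-array state: here just a step counter.
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

# Rescale the gradient and advance the counter; return (new state, value to subtract).
function Optimisers.apply!(o::DecayDescent, state, x, dx)
  newdx = (o.eta / sqrt(state)) .* dx
  nextstate = state + 1
  return nextstate, newdx
end
```
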
@@ -30,7 +30,7 @@ is a key design principle and allows users to manage their own state explicitly.

It of course also makes it easier to store the state.

- ## Usage
+ ## Usage with [Flux.jl](https://github.com/FluxML/Flux.jl)

To apply such an optimiser to a whole model, `setup` builds a tree containing any initial
state for every trainable array. Then at each step, `update` uses this and the gradient
@@ -40,29 +40,73 @@ to adjust the model:

using Flux, Metalhead, Optimisers

- model = Metalhead.ResNet18()                   # define a model to train on
- image = rand(Float32, 224, 224, 3, 1);         # dummy data
- @show sum(model(image));                       # dummy loss function
+ model = Metalhead.ResNet(18) |> gpu            # define a model to train
+ image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
+ @show sum(model(image));                       # dummy loss function

- o = Optimisers.ADAM()                          # define an ADAM optimiser with default settings
- st = Optimisers.setup(o, model);               # initialize the optimiser before using it
+ rule = Optimisers.Adam()                       # use the Adam optimiser with its default settings
+ state = Optimisers.setup(rule, model);         # initialise this optimiser's momentum etc.

- m̄, _ = gradient(model, image) do m, x          # calculate the gradients
+ ∇model, _ = gradient(model, image) do m, x     # calculate the gradients
    sum(m(x))
end;

- st, model = Optimisers.update(st, model, m̄);
+ state, model = Optimisers.update(state, model, ∇model);
@show sum(model(image));

```

Notice that a completely new instance of the model is returned. Internally, this
is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
- tree formed by the model and update the parameters using the gradients. Optimisers can
- work with different forms of gradients, but most likely use case are the gradients as
- returned by [Zygote.jl](https://fluxml.ai/Zygote.jl).
+ tree formed by the model and update the parameters using the gradients.
+
+ Optimisers.jl does not depend on any one automatic differentiation package,
+ but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
+ Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
+ This `∇model` is another tree structure, rather than the dictionary-like object from
+ Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
+ [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
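
The snippet below is not part of the original example; it illustrates the difference
using `model` and `image` from above, and assumes `gradient` is Zygote's, as re-exported by Flux:

```julia
# "Explicit" mode: a nested NamedTuple-like tree mirroring the model's structure.
∇model, _ = gradient((m, x) -> sum(m(x)), model, image)

# "Implicit" mode: a dictionary-like Zygote.Grads object keyed by the parameter arrays.
grads = gradient(() -> sum(model(image)), Flux.params(model))

# Only the first, tree-shaped form is what `Optimisers.update` expects.
```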

There is also `Optimisers.update!` which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` you write is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
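
As a rough sketch, re-using `model`, `image`, and `state` from the example above with the
same dummy loss, a small training loop with the mutating variant might look like this;
`train_steps` is just a name chosen for this illustration:

```julia
function train_steps(model, state, image; steps = 3)
  for _ in 1:steps
    ∇model, _ = gradient((m, x) -> sum(m(x)), model, image)   # same dummy loss as above
    state, model = Optimisers.update!(state, model, ∇model)   # may reuse the old arrays' memory
  end
  return state, model
end

state, model = train_steps(model, state, image);
```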
+
+ ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
+
+ The main design difference of Lux is that the tree of parameters is separate from
+ the layer structure. It is these parameters which `setup` and `update` need to know about.
+
+ Lux describes this separation of parameter storage from model description as "explicit" parameters.
+ Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
+ (If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be
+ identical trees of nested `NamedTuple`s.)
+
+ ```julia
+
+ using Lux, Boltz, Zygote, Optimisers
+
+ lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
+ images = rand(Float32, 224, 224, 3, 4) |> gpu;                  # batch of dummy data
+ y, _ = Lux.apply(lux_model, images, params, lux_state);         # run the model
+ @show sum(y)                                                    # initial dummy loss
+
+ rule = Optimisers.Adam()
+ opt_state = Optimisers.setup(rule, params);                     # optimiser state based on model parameters
+
+ ∇params, _ = gradient(params, images) do p, x                   # gradient with respect to parameter tree
+     y, _ = Lux.apply(lux_model, x, p, lux_state)
+     sum(y)
+ end;
+
+ opt_state, params = Optimisers.update!(opt_state, params, ∇params);
+
+ y, _ = Lux.apply(lux_model, images, params, lux_state);
+ @show sum(y)
+
+ ```
+
+ Besides the parameters stored in `params` and gradually optimised, any other model state
+ is stored in `lux_state`. For simplicity this example does not show how to propagate the
+ updated `lux_state` to the next iteration; see Lux's documentation for details.
+
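One simple (if slightly wasteful) way to carry the state forward, not taken from Lux's
documentation, is to re-run the forward pass after each update; `lux_train_steps` below
is purely illustrative and re-uses the objects from the example above:

```julia
function lux_train_steps(params, opt_state, lux_state; steps = 3)
  for _ in 1:steps
    ∇params, _ = gradient(params, images) do p, x
      y, _ = Lux.apply(lux_model, x, p, lux_state)   # state produced inside the gradient call is discarded
      sum(y)
    end
    opt_state, params = Optimisers.update!(opt_state, params, ∇params)
    _, lux_state = Lux.apply(lux_model, images, params, lux_state)  # refresh e.g. normalisation statistics
  end
  return params, opt_state, lux_state
end

params, opt_state, lux_state = lux_train_steps(params, opt_state, lux_state);
```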