
Commit 4e53612

mcabbott and darsnack authored
Document the need for explicit gradients (#80)
* document use of explicit gradients
* wording
* Update docs/src/index.md
  Co-authored-by: Kyle Daruwalla <[email protected]>
* wording
* add a Usage with Lux.jl section too
* further comment on model state
* better notation
* use the same resnet example for Lux
* pipe to gpu
* tweak resnet lines

Co-authored-by: Kyle Daruwalla <[email protected]>
1 parent 7f26f7f · commit 4e53612

2 files changed: +62 -14 lines changed


README.md

Lines changed: 6 additions & 2 deletions
@@ -21,7 +21,7 @@ Optimisers.jl defines many standard gradient-based optimisation rules, and tools
 
 This is the future of training for [Flux.jl](https://github.com/FluxML/Flux.jl) neural networks,
 and the present for [Lux.jl](https://github.com/avik-pal/Lux.jl).
-But it can be used separately on anything understood by [Functors.jl](https://github.com/FluxML/Functors.jl).
+But it can be used separately on any array, or anything else understood by [Functors.jl](https://github.com/FluxML/Functors.jl).
 
 ## Installation
 
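For instance (a minimal sketch, not shown in this diff), a rule can be applied directly to a plain array, with a hand-written vector standing in for a gradient from an AD package:

```julia
using Optimisers

x = [1.0, 2.0, 3.0]                      # any plain array, no Functors-aware model needed
state = Optimisers.setup(Optimisers.Descent(0.1), x)

grad = [1.0, 1.0, 1.0]                   # stand-in for the gradient of some loss at x
state, x = Optimisers.update(state, x, grad)   # x is now [0.9, 1.9, 2.9]
```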

@@ -38,11 +38,15 @@ state, and the model with its trainable parameters adjusted:
 ```julia
 state = Optimisers.setup(Optimisers.Adam(), model) # just once
 
+grad = Zygote.gradient(m -> loss(m(x), y), model)[1]
+
 state, model = Optimisers.update(state, model, grad) # at every step
 ```
 
 For models with deeply nested layers containing the parameters (like [Flux.jl](https://github.com/FluxML/Flux.jl) models),
-this state is a similarly nested tree.
+this state is a similarly nested tree. As is the gradient: if using Zygote, you must use the "explicit" style as shown,
+not the "implicit" one with `Params`.
+
 The function `destructure` collects all the trainable parameters into one vector,
 and returns this along with a function to re-build a similar model:
 
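As an illustration (a minimal sketch, not part of this diff; the small `Chain` below is hypothetical), `destructure` and the re-building function it returns are used roughly like this:

```julia
using Flux, Optimisers

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))   # small illustrative model

flat, re = Optimisers.destructure(model)   # flat is a Vector of all trainable parameters
model2 = re(2 .* flat)                     # re-builds a model of the same shape from any such vector
```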

docs/src/index.md

Lines changed: 56 additions & 12 deletions
@@ -1,6 +1,6 @@
 # Optimisers.jl
 
-## Define an Optimiser
+## Defining an Optimiser
 
 A new optimiser must overload two functions, `apply!` and `init`:
 
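As a reminder of that interface (a minimal sketch, not part of this diff; the rule `MyMomentum` is made up for illustration, and on recent versions of Optimisers.jl rules subtype `Optimisers.AbstractRule`), `init` creates the per-array state and `apply!` returns the change to subtract:

```julia
using Optimisers

struct MyMomentum <: Optimisers.AbstractRule
  eta::Float64
  rho::Float64
end

# init returns whatever state the rule keeps for one array, here a velocity buffer:
Optimisers.init(o::MyMomentum, x::AbstractArray) = zero(x)

# apply! returns (new state, change which update will subtract from x);
# it is free to mutate the state it was handed:
function Optimisers.apply!(o::MyMomentum, state, x, dx)
  @. state = o.rho * state + dx
  return state, o.eta .* state
end
```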

@@ -30,7 +30,7 @@ is a key design principle and allows users to manage their own state explicitly.
 
 It of course also makes it easier to store the state.
 
-## Usage
+## Usage with [Flux.jl](https://github.com/FluxML/Flux.jl)
 
 To apply such an optimiser to a whole model, `setup` builds a tree containing any initial
 state for every trainable array. Then at each step, `update` uses this and the gradient
@@ -40,29 +40,73 @@ to adjust the model:
 
 using Flux, Metalhead, Optimisers
 
-model = Metalhead.ResNet18() # define a model to train on
-image = rand(Float32, 224, 224, 3, 1); # dummy data
-@show sum(model(image)); # dummy loss function
+model = Metalhead.ResNet(18) |> gpu # define a model to train
+image = rand(Float32, 224, 224, 3, 1) |> gpu; # dummy data
+@show sum(model(image)); # dummy loss function
 
-o = Optimisers.ADAM() # define an ADAM optimiser with default settings
-st = Optimisers.setup(o, model); # initialize the optimiser before using it
+rule = Optimisers.Adam() # use the Adam optimiser with its default settings
+state = Optimisers.setup(rule, model); # initialise this optimiser's momentum etc.
 
-m̄, _ = gradient(model, image) do m, x # calculate the gradients
+∇model, _ = gradient(model, image) do m, x # calculate the gradients
   sum(m(x))
 end;
 
-st, model = Optimisers.update(st, model, m̄);
+state, model = Optimisers.update(state, model, ∇model);
 @show sum(model(image));
 
 ```
 
 Notice that a completely new instance of the model is returned. Internally, this
 is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
-tree formed by the model and update the parameters using the gradients. Optimisers can
-work with different forms of gradients, but most likely use case are the gradients as
-returned by [Zygote.jl](https://fluxml.ai/Zygote.jl).
+tree formed by the model and update the parameters using the gradients.
+
+Optimisers.jl does not depend on any one automatic differentiation package,
+but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
+Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
+This `∇model` is another tree structure, rather than the dictionary-like object from
+Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
+[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
 There is also `Optimisers.update!` which similarly returns a new model and new state,
 but is free to mutate arrays within the old one for efficiency.
 The method of `apply!` you write is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.
+
+## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
+
+The main design difference of Lux is that the tree of parameters is separate from
+the layer structure. It is these parameters which `setup` and `update` need to know about.
+
+Lux describes this separation of parameter storage from model description as "explicit" parameters.
+Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
+(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be
+identical trees of nested `NamedTuple`s.)
+
+```julia
+
+using Lux, Boltz, Zygote, Optimisers
+
+lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model
+images = rand(Float32, 224, 224, 3, 4) |> gpu; # batch of dummy data
+y, _ = Lux.apply(lux_model, images, params, lux_state); # run the model
+@show sum(y) # initial dummy loss
+
+rule = Optimisers.Adam()
+opt_state = Optimisers.setup(rule, params); # optimiser state based on model parameters
+
+∇params, _ = gradient(params, images) do p, x # gradient with respect to parameter tree
+  y, _ = Lux.apply(lux_model, x, p, lux_state)
+  sum(y)
+end;
+
+opt_state, params = Optimisers.update!(opt_state, params, ∇params);
+
+y, _ = Lux.apply(lux_model, images, params, lux_state);
+@show sum(y)
+
+```
+
+Besides the parameters stored in `params` and gradually optimised, any other model state
+is stored in `lux_state`. For simplicity this example does not show how to propagate the
+updated `lux_state` to the next iteration, see Lux's documentation.
+
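For completeness, here is one way (a sketch, not part of this commit; the step count is arbitrary and the dummy `images` above stand in for real data) to carry the updated Lux state from one step to the next:

```julia
using Lux, Zygote, Optimisers

# Continues the example above. Zygote.pullback returns both outputs of the closure,
# so the loss can be differentiated while the updated Lux state is kept.
function train!(lux_model, params, lux_state, opt_state, images; steps = 3)
    for _ in 1:steps
        (l, new_state), back = Zygote.pullback(params) do p
            y, st = Lux.apply(lux_model, images, p, lux_state)
            sum(y), st
        end
        ∇params = back((one(l), nothing))[1]   # gradient w.r.t. the parameter tree only
        opt_state, params = Optimisers.update!(opt_state, params, ∇params)
        lux_state = new_state                  # thread the non-trainable state forward
    end
    return params, lux_state, opt_state
end

params, lux_state, opt_state = train!(lux_model, params, lux_state, opt_state, images);
```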
