Commit acd5ca3

mcabbott and ToucheSir authored
Add doc section about destructure (#82)
* Add doc section about destructure
* more links
* Update docs/src/index.md
  Co-authored-by: Brian Chen <[email protected]>
* shorten Lux example
  Co-authored-by: Brian Chen <[email protected]>
1 parent 4e53612 commit acd5ca3


docs/src/index.md

Lines changed: 59 additions & 6 deletions
@@ -1,8 +1,9 @@
 # Optimisers.jl
 
-## Defining an Optimiser
+## Defining an optimisation rule
 
-A new optimiser must overload two functions, `apply!` and `init`:
+A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref).
+These act on one array of parameters:
 
 ```julia
 # Define a container to hold any optimiser specific parameters (if any):
@@ -27,13 +28,12 @@ carried to the next iteration.
 
 Notice that the state is handled separately from the optimiser itself. This
 is a key design principle and allows users to manage their own state explicitly.
-
 It of course also makes it easier to store the state.
 
 ## Usage with [Flux.jl](https://github.com/FluxML/Flux.jl)
 
-To apply such an optimiser to a whole model, `setup` builds a tree containing any initial
-state for every trainable array. Then at each step, `update` uses this and the gradient
+To apply such an optimiser to a whole model, [`setup`](@ref) builds a tree containing any initial
+state for every trainable array. Then at each step, [`update`](@ref) uses this and the gradient
 to adjust the model:
 
 ```julia
@@ -67,7 +67,7 @@ This `∇model` is another tree structure, rather than the dictionary-like object
 Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
-There is also `Optimisers.update!` which similarly returns a new model and new state,
+There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
 but is free to mutate arrays within the old one for efficiency.
 The method of `apply!` you write is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.
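For a concrete sense of this contract, here is a minimal sketch using only Optimisers itself, with a plain named tuple standing in for a model and a hand-written gradient (all values arbitrary):

```julia
using Optimisers

model = (weight = [1.0, 2.0], bias = [0.5])        # any Functors-compatible structure of arrays
state = Optimisers.setup(Optimisers.Descent(0.1), model)
grad  = (weight = [0.1, 0.2], bias = [0.3])        # a hand-written "gradient" of matching shape

# `update` never mutates: it returns a fresh state tree and a fresh model.
state2, model2 = Optimisers.update(state, model, grad)

# `update!` returns the same things, but is free to overwrite the arrays
# inside the old `state2` and `model2`, so do not reuse them afterwards.
state3, model3 = Optimisers.update!(state2, model2, grad)
```

In both cases it is the returned state and model that should be carried forward to the next step.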
@@ -110,3 +110,56 @@ Besides the parameters stored in `params` and gradually optimised, any other model state
 is stored in `lux_state`. For simplicity this example does not show how to propagate the
 updated `lux_state` to the next iteration, see Lux's documentation.
 
+## Obtaining a flat parameter vector
+
+Instead of a nested tree-like structure, sometimes it is convenient to have all the
+parameters as one simple vector. Optimisers.jl contains a function [`destructure`](@ref)
+which creates this vector, and also creates a way to re-build the original structure
+with new parameters. Both flattening and re-building may be used within `gradient` calls.
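For orientation, a toy sketch of the pair which `destructure` returns, using an arbitrary nested named tuple:

```julia
using Optimisers

nt = (x = [1.0, 2.0], y = (a = [3.0],))   # any nested structure of arrays

flat, re = destructure(nt)   # flat == [1.0, 2.0, 3.0]; re re-builds the structure
re(flat .* 10)               # == (x = [10.0, 20.0], y = (a = [30.0],))
```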
+
+An example with Flux's `model`:
+
+```julia
+using ForwardDiff  # an example of a package which only likes one array
+
+model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
+  Conv((3, 3), 3 => 5, pad=1, bias=false),
+  BatchNorm(5, relu),
+  Conv((3, 3), 5 => 3, stride=16),
+)
+image = rand(Float32, 224, 224, 3, 1);
+@show sum(model(image));
+
+flat, re = destructure(model)
+st = Optimisers.setup(rule, flat)  # state is just one Leaf now
+
+∇flat = ForwardDiff.gradient(flat) do v
+  m = re(v)      # rebuild a new object like model
+  sum(m(image))  # call that as before
+end
+
+st, flat = Optimisers.update(st, flat, ∇flat)
+@show sum(re(flat)(image));
+```
+
+Here `flat` contains only the 283 trainable parameters, while the non-trainable
+ones are preserved inside `re`.
+When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref).
+By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl)
+are assumed to contain trainable parameters.
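As a sketch of what such a `trainable` overload can look like, for a hypothetical layer holding a numeric buffer that should not be optimised (names here are illustrative):

```julia
using Optimisers, Functors

struct MyLayer
  weight::Matrix{Float64}
  history::Vector{Float64}   # numeric, but we do not want it optimised
end
Functors.@functor MyLayer

# Restrict optimisation to `weight`; `history` is still walked by Functors,
# but setup/update and destructure will leave it alone.
Optimisers.trainable(m::MyLayer) = (; weight = m.weight)
```

With this definition, `setup`, `update` and `destructure` act only on `weight`, while `history` is carried along unchanged.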
+
+Lux stores only the trainable parameters in `params`.
+This can also be flattened to a plain `Vector` in the same way:
+
+```julia
+params, lux_state = Lux.setup(Random.default_rng(), lux_model);
+
+flat, re = destructure(params)
+
+∇flat = ForwardDiff.gradient(flat) do v
+  p = re(v)  # rebuild an object like params
+  y, _ = Lux.apply(lux_model, images, p, lux_state)
+  sum(y)
+end
+```
+