Commit b81dabb

Move usage and such into the docs
1 parent c960eda commit b81dabb


docs/src/index.md

Lines changed: 51 additions & 0 deletions
@@ -1,3 +1,54 @@
# Optimisers.jl

## Define an Optimiser

```julia
# Define a container to hold any optimiser-specific parameters (if any)
struct Descent{T}
  η::T
end

# Define an `apply` rule with which to update the current params
# using the gradients
function Optimisers.apply(o::Descent, state, m, m̄)
  o.η .* m̄, state
end

# Plain gradient descent keeps no per-parameter state
Optimisers.init(o::Descent, x::AbstractArray) = nothing
```

Notice that the state is handled separately from the optimiser itself. This is a key design principle: it allows users to manage their own state explicitly. It also makes the state easier to store.
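
As a small illustration of this separation, the rule defined above can be driven by hand on a plain array. This is only a sketch: `x` and `x̄` are illustrative names, and it assumes the value returned by `apply` is the step that gets subtracted from the parameters.

```julia
# Hand-driven update with the Descent rule defined above (illustrative sketch).
x  = rand(Float32, 3)          # a parameter array
x̄  = ones(Float32, 3)          # a gradient for it
o  = Descent(0.1f0)            # the rule defined above
st = Optimisers.init(o, x)     # rule-specific state (`nothing` for Descent)

Δ, st = Optimisers.apply(o, st, x, x̄)  # the step to take, plus the updated state
x = x .- Δ                             # assumed convention: subtract the step
```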

## Usage

```julia
using Flux, Metalhead, Optimisers

model = ResNet()                    # define a model to train
ip = rand(Float32, 224, 224, 3, 1)  # dummy data

o = Optimisers.ADAM()               # define an ADAM optimiser with default settings
st = Optimisers.state(o, model)     # initialize the optimiser state before using it

m̄, _ = gradient(model, ip) do m, x  # calculate the gradients
  sum(m(x))
end

st, mnew = Optimisers.update(o, st, model, m̄)

# or, equivalently
st, mnew = o(model, m̄, st)
```

Notice that a completely new instance of the model is returned. Internally, this is handled by [Functors.jl](https://fluxml.ai/Functors.jl), which walks over the tree formed by the model and updates the parameters using the gradients. Optimisers can work with different forms of gradients, but the most likely use case is the gradients returned by [Zygote.jl](https://fluxml.ai/Zygote.jl).
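
Because `update` returns new state and a new model rather than mutating in place, a training loop simply threads both values through each step. Below is a minimal sketch; `train`, `data`, and `loss` are illustrative names, not part of the packages above.

```julia
# A sketch of a training loop built on the API shown above.
# `train`, `data`, and `loss` are placeholders, not part of Optimisers.jl.
function train(model, data, loss; epochs = 1)
  o  = Optimisers.ADAM()
  st = Optimisers.state(o, model)
  for _ in 1:epochs, (x, y) in data
    m̄, _ = gradient(model, x) do m, xb          # gradients with respect to the model
      loss(m(xb), y)
    end
    st, model = Optimisers.update(o, st, model, m̄)  # new state and a new model each step
  end
  return model
end
```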