
Commit 34250b2

Update doc pages and docstrings (#200)
1 parent 5f2680a commit 34250b2

3 files changed: +23 −9 lines changed

docs/src/index.md

Lines changed: 19 additions & 5 deletions
@@ -1,5 +1,23 @@
 # Optimisers.jl
 
+Optimisers.jl defines many standard gradient-based optimisation rules, and tools for applying them to deeply nested models.
+
+This was written as the new training system for [Flux.jl](https://github.com/FluxML/Flux.jl) neural networks,
+and also used by [Lux.jl](https://github.com/LuxDL/Lux.jl).
+But it can be used separately on any array, or anything else understood by [Functors.jl](https://github.com/FluxML/Functors.jl).
+
+## Installation
+
+In the Julia REPL, type
+```julia
+]add Optimisers
+```
+
+or
+```julia-repl
+julia> import Pkg; Pkg.add("Optimisers")
+```
+
 ## An optimisation rule
 
 A new optimiser must overload two functions, [`apply!`](@ref Optimisers.apply!) and [`init`](@ref Optimisers.init).
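To make the `apply!` / `init` contract concrete, here is a minimal sketch of a hand-written rule. It is not part of this commit; the name `MyDescent` is invented, but the two methods follow the signatures documented on this page: `init(rule, x)` returns the per-array state, and `apply!(rule, state, x, dx)` returns the new state together with the adjusted gradient that `update` will subtract.

```julia
using Optimisers

# Hypothetical rule, for illustration only: plain gradient descent with learning rate `eta`.
struct MyDescent <: Optimisers.AbstractRule
    eta::Float64
end

# Plain descent needs no per-array state.
Optimisers.init(o::MyDescent, x::AbstractArray) = nothing

# Return (unchanged state, adjusted gradient); `update` subtracts the result from `x`.
Optimisers.apply!(o::MyDescent, state, x, dx) = state, o.eta .* dx

# Usage: behaves like the built-in Descent rule.
x = [1.0, 2.0, 3.0]
st = Optimisers.setup(MyDescent(0.1), x)
st, x = Optimisers.update(st, x, [1.0, 1.0, 1.0])   # x is now [0.9, 1.9, 2.9]
```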
@@ -38,7 +56,6 @@ state for every trainable array. Then at each step, [`update`](@ref Optimisers.u
 to adjust the model:
 
 ```julia
-
 using Flux, Metalhead, Zygote, Optimisers
 
 model = Metalhead.ResNet(18) |> gpu # define a model to train
@@ -54,7 +71,6 @@ end;
 
 state_tree, model = Optimisers.update(state_tree, model, ∇model);
 @show sum(model(image)); # reduced
-
 ```
 
 Notice that a completely new instance of the model is returned. Internally, this
@@ -91,7 +107,6 @@ Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
 identical trees of nested `NamedTuple`s.)
 
 ```julia
-
 using Lux, Boltz, Zygote, Optimisers
 
 lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model
@@ -113,7 +128,6 @@ opt_state, params = Optimisers.update!(opt_state, params, ∇params);
 
 y, lux_state = Lux.apply(lux_model, images, params, lux_state);
 @show sum(y); # now reduced
-
 ```
 
 Besides the parameters stored in `params` and gradually optimised, any other model state
@@ -297,7 +311,7 @@ similarly to what [`destructure`](@ref Optimisers.destructure) does but without
 concatenating the arrays into a flat vector.
 This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:
 
-```julia
+```julia-repl
 julia> using Flux, Optimisers
 
 julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));
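The hunks above show only the changed fragments of the Flux and Lux walkthroughs. For orientation, the workflow they describe — `setup` once to build the state tree, then `update` (or the mutating `update!`) at every step — can be sketched on a plain array with no extra packages; the values below are made up:

```julia
using Optimisers

x = [1.0, 2.0, 3.0]                       # any array, or a nested model
state = Optimisers.setup(Adam(0.01), x)   # one state leaf per trainable array
grad = [0.5, 0.5, 0.5]                    # stand-in for a real gradient
state, x = Optimisers.update(state, x, grad)    # returns a new state and new parameters
state, x = Optimisers.update!(state, x, grad)   # same, but free to mutate its inputs
```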

src/destructure.jl

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ This is what [`destructure`](@ref Optimisers.destructure) returns, and `re(p)` w
 new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`.
 
 # Example
-```julia
+```julia-repl
 julia> using Flux, Optimisers
 
 julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid))
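The edit above only switches the docstring's code fence to `julia-repl`. As a reminder of what that docstring documents, here is a small sketch of `destructure` applied to a plain `NamedTuple` (assuming only that the structure is understood by Functors.jl); the numbers are arbitrary:

```julia
using Optimisers

# Flatten all numeric arrays into one vector, and keep a function to rebuild the structure.
flat, re = Optimisers.destructure((x = [1.0, 2.0], y = [3.0]))
flat            # [1.0, 2.0, 3.0]
re(2 .* flat)   # (x = [2.0, 4.0], y = [6.0])
```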

src/trainables.jl

Lines changed: 3 additions & 3 deletions
@@ -32,10 +32,10 @@ julia> trainables(x)
 1-element Vector{AbstractArray}:
  [1.0, 2.0, 3.0]
 
-julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);
+julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);
 
-julia> trainables(x) # collects nested parameters
-2-element Vector{AbstractArray}:
+julia> trainables(x) # collects nested parameters
+2-element Vector{AbstractArray}:
  [1.0, 2.0]
  [3.0]
 ```
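For comparison with the `MyLayer` docstring example above, `trainables` can also be tried on an ordinary nested `NamedTuple`, where every numeric array is collected and non-array leaves such as functions are skipped. This is an illustrative sketch, not part of the commit:

```julia
using Optimisers

nt = (layer = (weight = [1.0 2.0; 3.0 4.0], bias = [0.0, 0.0]), σ = tanh)
ps = Optimisers.trainables(nt)   # 2-element Vector{AbstractArray}: the weight matrix and the bias
sum(length, ps)                  # 6, the total number of trainable parameters
```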
