# Optimisers.jl

Optimisers.jl defines many standard gradient-based optimisation rules, and tools for applying them to deeply nested models.

This was written as the new training system for [Flux.jl](https://github.com/FluxML/Flux.jl) neural networks,
and is also used by [Lux.jl](https://github.com/LuxDL/Lux.jl).
But it can be used separately on any array, or anything else understood by [Functors.jl](https://github.com/FluxML/Functors.jl).
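For instance, a plain vector can be optimised directly, with no model wrapper at all (a minimal sketch using `Optimisers.setup` and `Optimisers.update`; the gradient here is hard-coded purely for illustration):

```julia
using Optimisers

x = [1.0, 2.0, 3.0]  # just an array, no layers or model structure

state = Optimisers.setup(Optimisers.Descent(0.1), x)  # per-array optimiser state

# one step, with a hard-coded gradient of ones (illustrative only)
state, x = Optimisers.update(state, x, [1.0, 1.0, 1.0])

x  # now [0.9, 1.9, 2.9]: each entry moved by -0.1 * gradient
```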
8+
## Installation

In the Julia REPL, type

```julia
]add Optimisers
```

or

```julia-repl
julia> import Pkg; Pkg.add("Optimisers")
```
20+
## An optimisation rule

A new optimiser must overload two functions, [`apply!`](@ref Optimisers.apply!) and [`init`](@ref Optimisers.init).
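As a sketch of the expected shape (assuming the documented signatures, where `init` returns the initial state for one array and `apply!` returns the new state plus a rescaled gradient; the `DecayDescent` rule here is illustrative):

```julia
using Optimisers

struct DecayDescent <: Optimisers.AbstractRule
  eta::Float64
end

# init returns the initial state for one array of parameters
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

# apply! returns (new state, rescaled gradient); here the
# effective step size shrinks like eta / sqrt(t) over time
function Optimisers.apply!(o::DecayDescent, state, x, dx)
  newdx = (o.eta / sqrt(state)) .* dx
  return state + 1, newdx
end
```

Note that `apply!` itself never modifies the parameters `x`; subtracting the returned gradient is handled by `update`.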
state for every trainable array. Then at each step, [`update`](@ref Optimisers.update) is used
to adjust the model:
3957
```julia
using Flux, Metalhead, Zygote, Optimisers

model = Metalhead.ResNet(18) |> gpu  # define a model to train
image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
@show sum(model(image));  # dummy loss function

rule = Optimisers.Adam()  # use the Adam optimiser with its default settings
state_tree = Optimisers.setup(rule, model);  # initialise this optimiser's momentum etc.

∇model, _ = gradient(model, image) do m, x  # calculate the gradients
  sum(m(x))
end;

state_tree, model = Optimisers.update(state_tree, model, ∇model);
@show sum(model(image));  # reduced
```
5975
Notice that a completely new instance of the model is returned. Internally, this is handled by [Functors.jl](https://github.com/FluxML/Functors.jl).
Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
identical trees of nested `NamedTuple`s.)
92108
```julia
using Lux, Boltz, Zygote, Optimisers

lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
images = rand(Float32, 224, 224, 3, 4) |> gpu;  # batch of dummy data

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params);  # optimiser state based on the parameter tree

# gradient with respect to the parameter tree only
∇params, _ = gradient((p, x) -> sum(first(Lux.apply(lux_model, x, p, lux_state))), params, images);

opt_state, params = Optimisers.update!(opt_state, params, ∇params);

y, lux_state = Lux.apply(lux_model, images, params, lux_state);
@show sum(y);  # now reduced
```
118132
Besides the parameters stored in `params` and gradually optimised, any other model state is stored in `lux_state`.
similarly to what [`destructure`](@ref Optimisers.destructure) does but without
concatenating the arrays into a flat vector.
This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:
299313
```julia-repl
julia> using Flux, Optimisers

julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));
```