@@ -63,8 +63,34 @@ or [`update!`](@ref).
63
63
64
64
# Example
65
65
```jldoctest
66
- julia> Optimisers.setup(Descent(0.1f0), (x = rand(3), y = (true, false), z = tanh))
67
- (x = Leaf(Descent{Float32}(0.1), nothing), y = (nothing, nothing), z = nothing)
66
+ julia> m = (x = rand(3), y = (true, false), z = tanh);
67
+
68
+ julia> Optimisers.setup(Momentum(), m) # same field names as m
69
+ (x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = (nothing, nothing), z = nothing)
70
+ ```
71
+
72
+ The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
73
+ need to be marked with `Functors.@functor` before use.
74
+ See [the Flux docs](https://fluxml.ai/Flux.jl/stable/models/advanced/) for more about this.
75
+
76
+ ```
77
+ julia> struct Layer; mat; fun; end
78
+
79
+ julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);
80
+
81
+ julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored
82
+ (lay = nothing, vec = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0]))
83
+
84
+ julia> destructure(model)
85
+ (Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))
86
+
87
+ julia> using Functors; @functor Layer # annotate this type as containing parameters
88
+
89
+ julia> Optimisers.setup(Momentum(), model)
90
+ (lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), [0.0 0.0; 0.0 0.0]), fun = nothing), vec = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0]))
91
+
92
+ julia> destructure(model)
93
+ (Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
68
94
```
69
95
"""
70
96
setup
0 commit comments