Commit 2da6d7f

docs for nothing behavior and for walking a tree with keypath (#191)
* cl/zero
* Update docs/src/index.md

Co-authored-by: Michael Abbott <[email protected]>
1 parent e5d187c commit 2da6d7f

File tree: 2 files changed, +51 -1 lines changed

docs/src/index.md

Lines changed: 43 additions & 1 deletion
@@ -311,8 +311,50 @@ julia> trainables(model)
 Float32[-0.8764882 0.40812716 0.1919528; -0.9123545 -0.4462516 0.6751252]
 Float32[0.0, 0.0]
 
-julia> l2reg(model) = sum([sum(abs2,p) for p in trainables(model)]);
+julia> l2reg(model) = sum([sum(abs2, p) for p in trainables(model)]);
 
 julia> g = gradient(l2reg, model)[1];
 ```
 Notice that the `BatchNorm` layer has two trainable parameters, `γ` and `β`, which are included in the list, while the `μ` and `σ²` buffers are not.
+
+Sometimes one wants to iterate over all the trainable parameters in a model together with the corresponding entries of a matched structure, such as a gradient or a moving average of the model.
+This can be done using `trainables(model, path=true)`. For instance, here is how to update the parameters
+of a moving-average model from the parameters of the model:
+
+```julia
+for (kp, p_avg) in trainables(model_avg, path=true)
+    p = getkeypath(model, kp)
+    p_avg .= 0.99 .* p_avg .+ 0.01 .* p
+end
+```
+
+## Incomplete or nothing gradients
+
+If the gradient is not available for some parameters, or for some branches of the model,
+`update` will not take an optimisation step for those parameters.
+This is the case when the gradient is `nothing` or a subtype of `ChainRules.AbstractZero`.
+
+For stateful optimisers, skipping an update is generally not the same as updating with a zero gradient.
+For example, in the case of Adam, the momentum and variance are updated even if the gradient is zero:
+
+```julia-repl
+julia> x = (a = ones(2), b = ones(2))
+(a = [1.0, 1.0], b = [1.0, 1.0])
+
+julia> opt_state = Optimisers.setup(Adam(0.1), x)
+(a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.0, 0.0], [0.0, 0.0], (0.9, 0.999))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.0, 0.0], [0.0, 0.0], (0.9, 0.999))))
+
+julia> g = (; a = ones(2), b = ones(2));  # First an update with a non-zero gradient to increase the momentum and variance
+
+julia> Optimisers.update!(opt_state, x, g);
+
+julia> opt_state  # the states of `a` and `b` are the same
+(a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
+
+julia> g = (; a = zeros(2), b = nothing);  # Now an update with a zero gradient for `a` and no gradient for `b`
+
+julia> Optimisers.update!(opt_state, x, g);
+
+julia> opt_state  # the states of `a` and `b` now differ
+(a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.09, 0.09], [0.000999, 0.000999], (0.729, 0.997003))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
+```
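
Here is a minimal, self-contained sketch of the key-path pattern added above; it is not part of this commit. The model is a plain nested NamedTuple rather than a Flux model, the field names are illustrative, and `getkeypath` is assumed to be the function provided by Functors.jl.

```julia
using Optimisers               # provides trainables
using Functors: getkeypath     # assumption: key-path lookup comes from Functors.jl

# Illustrative stand-in for a real model and its moving-average copy:
model     = (layer = (weight = rand(2, 2), bias = zeros(2)), scale = [1.0])
model_avg = deepcopy(model)

# Walk every trainable leaf of `model_avg` together with its key path,
# fetch the matching leaf in `model`, and blend it in place:
for (kp, p_avg) in trainables(model_avg, path = true)
    p = getkeypath(model, kp)
    p_avg .= 0.99 .* p_avg .+ 0.01 .* p
end
```

Matching by `KeyPath` rather than by position keeps the two structures aligned even when only some of their leaves are trainable.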

src/interface.jl

Lines changed: 8 additions & 0 deletions
@@ -103,13 +103,21 @@ end
 subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
 subtract!(x, x̄::Zero) = x
 
+# If we get Zero from AD on a leaf we skip the optimizer step. See
+# https://github.com/FluxML/Optimisers.jl/issues/140
 _grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing
+
 function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
   x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
   dict[ℓ] = map(+, x̄s, x̄s₀)  # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
   nothing
 end
+
+# If we get Zero from AD at a non-leaf node
+# we end the recursion. The optimizer step won't be taken.
+# https://github.com/FluxML/Optimisers.jl/issues/140
 _grads!(dict::IdDict, t, x, ::Zero...) = nothing
+
 function _grads!(dict::IdDict, tree, x, x̄s...)
   # The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from
   # functor(typeof(tree), base(x̄)), for things like Transpose
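
A short sketch of the user-facing behaviour these methods implement; it is not part of the commit, and it uses `Descent` only to keep the numbers obvious. A `nothing` entry in the gradient means that branch is skipped entirely, so neither the parameter nor its optimiser state changes.

```julia
using Optimisers

x  = (a = ones(2), b = ones(2))
st = Optimisers.setup(Descent(0.1), x)

# `b = nothing` means "no gradient for b": that branch is left alone.
st, x = Optimisers.update(st, x, (a = [1.0, 1.0], b = nothing))

x.a   # [0.9, 0.9] -- moved by -0.1 * gradient
x.b   # [1.0, 1.0] -- untouched
```

The REPL example added to the docs shows the complementary point for Adam: an explicit zero gradient still advances the momentum and variance, while `nothing` leaves the `Leaf` state exactly as it was.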
