|
| 1 | +### |
| 2 | +### freeze! |
| 3 | +### |
| 4 | + |
| 5 | +""" |
| 6 | + Optimisers.freeze!(tree) |
| 7 | +
|
| 8 | +Temporarily alters the state `tree = setup(rule, model)` so that parameters will not be updated. |
| 9 | +Can be applied to the state corresponding to only part of a model, for instance `model.layers[1]`. |
| 10 | +Un-done by [`thaw!`](@ref Optimisers.thaw). |
| 11 | +
|
| 12 | +# Example |
| 13 | +```jldoctest |
| 14 | +julia> m = (x = ([1.0], 2.0), y = [3.0]); |
| 15 | +
|
| 16 | +julia> s = Optimisers.setup(Momentum(), m); |
| 17 | +
|
| 18 | +julia> Optimisers.freeze!(s.x) |
| 19 | +
|
| 20 | +julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi])); # with fake gradient |
| 21 | +
|
| 22 | +julia> m |
| 23 | +(x = ([1.0], 2.0), y = [-0.14159258336972558]) |
| 24 | +
|
| 25 | +julia> s # Leaf(..., true) means frozen |
| 26 | +(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159])) |
| 27 | +
|
| 28 | +julia> Optimisers.thaw!(s) |
| 29 | +
|
| 30 | +julia> s.x[1] |
| 31 | +Leaf(Momentum{Float32}(0.01, 0.9), [0.0]) |
| 32 | +``` |
| 33 | +""" |
| 34 | +freeze!(tree) = (fmapstructure(freeze!, tree; exclude = x -> x isa Leaf); nothing) |
| 35 | +freeze!(ℓ::Leaf) = (ℓ.frozen = true; nothing) |
| 36 | + |
| 37 | +""" |
| 38 | + Optimisers.thaw!(tree) |
| 39 | +
|
| 40 | +Un-does [`freeze!`](@ref Optimisers.freeze!) for all parameters, |
| 41 | +mutating every `Leaf(rule, state, true)` to `Leaf(rule, state, false)`. |
| 42 | +""" |
| 43 | +thaw!(tree) = (fmapstructure(thaw!, tree; exclude = x -> x isa Leaf); nothing) |
| 44 | +thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing) |
| 45 | + |
| 46 | +freeze!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError( |
| 47 | + "`freeze!` must not be applied to a model, only to the state tree from `setup`")) |
| 48 | +thaw!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError( |
| 49 | + "`thaw!` must not be applied to a model, only to the state tree from `setup`")) |
| 50 | + |
| 51 | +### |
| 52 | +### adjust |
| 53 | +### |
1 | 54 |
|
2 | 55 | """
|
3 | 56 | Optimisers.adjust(tree, η) -> tree
|
|
0 commit comments