Skip to content

Commit adfe6e4

Browse files
committed
add freeze/thaw
1 parent 9c12e5d commit adfe6e4

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ Optimisers.setup
3535
Optimisers.update
3636
Optimisers.update!
3737
Optimisers.adjust(::Any, ::Real)
38+
Optimisers.freeze!
39+
Optimisers.thaw!
3840
```
3941

4042
Calling `Functors.@functor` on your model's layer types by default causes

src/adjust.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,56 @@
1+
###
2+
### freeze!
3+
###
4+
5+
"""
6+
Optimisers.freeze!(tree)
7+
8+
Temporarily alters the state `tree = setup(rule, model)` so that parameters will not be updated.
9+
Can be applied to the state corresponding to only part of a model, for instance `model.layers[1]`.
10+
Un-done by [`thaw!`](@ref Optimisers.thaw).
11+
12+
# Example
13+
```jldoctest
14+
julia> m = (x = ([1.0], 2.0), y = [3.0]);
15+
16+
julia> s = Optimisers.setup(Momentum(), m);
17+
18+
julia> Optimisers.freeze!(s.x)
19+
20+
julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi])); # with fake gradient
21+
22+
julia> m
23+
(x = ([1.0], 2.0), y = [-0.14159258336972558])
24+
25+
julia> s # Leaf(..., true) means frozen
26+
(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
27+
28+
julia> Optimisers.thaw!(s)
29+
30+
julia> s.x[1]
31+
Leaf(Momentum{Float32}(0.01, 0.9), [0.0])
32+
```
33+
"""
34+
freeze!(tree) = (fmapstructure(freeze!, tree; exclude = x -> x isa Leaf); nothing)
35+
freeze!(ℓ::Leaf) = (ℓ.frozen = true; nothing)
36+
37+
"""
38+
Optimisers.thaw!(tree)
39+
40+
Un-does [`freeze!`](@ref Optimisers.freeze!) for all parameters,
41+
mutating every `Leaf(rule, state, true)` to `Leaf(rule, state, false)`.
42+
"""
43+
thaw!(tree) = (fmapstructure(thaw!, tree; exclude = x -> x isa Leaf); nothing)
44+
thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing)
45+
46+
freeze!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
47+
"`freeze!` must not be applied to a model, only to the state tree from `setup`"))
48+
thaw!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
49+
"`thaw!` must not be applied to a model, only to the state tree from `setup`"))
50+
51+
###
52+
### adjust
53+
###
154

255
"""
356
Optimisers.adjust(tree, η) -> tree

src/interface.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ abstract type AbstractRule end
1010
### setup
1111
###
1212

13-
mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing
13+
mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing...
1414
rule::R
1515
state::S
16+
frozen::Bool # ... and to allow freeze! to act on this.
1617
end
18+
Leaf(rule, state) = Leaf(rule, state, false)
1719

1820
@functor Leaf
1921

@@ -46,6 +48,7 @@ function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long
4648
ioc = IOContext(io, :compact => true)
4749
print(ioc, "Leaf(", ℓ.rule, ", ")
4850
show(ioc, ℓ.state)
51+
.frozen && print(ioc, ", true")
4952
print(ioc, ")")
5053
end
5154

@@ -83,6 +86,7 @@ function _update!(tree, x; grads, params)
8386
end
8487
function _update!(ℓ::Leaf, x; grads, params)
8588
haskey(params, (ℓ,x)) && return params[(ℓ,x)]
89+
.frozen && return x
8690
params[(ℓ,x)] = if haskey(grads, ℓ)
8791
.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
8892
subtract!(x, x̄′)

0 commit comments

Comments
 (0)