Skip to content

Commit 5b6b380

Browse files
authored
minimal trainable (#36)
1 parent 132bdd8 commit 5b6b380

File tree

3 files changed

+75
-7
lines changed

3 files changed

+75
-7
lines changed

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ Optimisers.update
3535
Optimisers.update!
3636
```
3737

38+
Calling `Functors.@functor` on your model's layer types by default causes the
39+
optimiser to act on all suitable fields. To restrict this, define `trainable`:
40+
41+
```@docs
42+
Optimisers.trainable
43+
```
44+
3845
## Rule Definition
3946

4047
```@docs

src/interface.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@ function setup(rule, x; seen = Base.IdSet())
1212
elseif isleaf(x)
1313
return nothing
1414
else
15-
x′, _ = functor(x)
16-
return map(xᵢ -> setup(rule, xᵢ; seen), x′)
15+
return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x))
1716
end
1817
end
1918

2019
subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : (x .- x̄)
2120

21+
update!(::Nothing, x, x̄s...) = nothing, x
22+
2223
function update!(ℓ::Leaf, x, x̄s...)
2324
if all(isnothing, x̄s)
2425
return ℓ, x
@@ -55,6 +56,26 @@ isnumeric(x) = false
5556
iswriteable(::DenseArray{<:AbstractFloat}) = true # more elaborate versions are possible, wait until needed?
5657
iswriteable(_) = false
5758

59+
"""
60+
trainable(x::Layer) -> NamedTuple
61+
62+
This should be overloaded to make optimisers ignore some fields of
63+
every `Layer`, which would otherwise contain trainable parameters.
64+
(Elements such as functions and sizes are always ignored.)
65+
66+
The default is `Functors.children(x)`, usually a NamedTuple of all fields,
67+
and `trainable(x)` must contain a subset of these.
68+
"""
69+
trainable(x) = functor(x)[1]
70+
71+
_trainable(x) = _trainable(functor(x)[1], trainable(x))
72+
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
73+
_trainable(ch::Tuple, tr::Tuple) = tr
74+
function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple
75+
@warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple"
76+
map(c -> c in tr ? c : nothing, ch)
77+
end
78+
5879
"""
5980
@.. x = x + y
6081
@.. x + y / z

test/runtests.jl

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
1-
using Optimisers, Test
2-
using Zygote
3-
using Statistics, Random, LinearAlgebra
4-
Random.seed!(1)
1+
using Optimisers, Functors, Zygote
2+
using LinearAlgebra, Statistics, Test, Random
53
using Optimisers: @..
64

5+
Random.seed!(1)
6+
7+
struct Foo; x; y; end
8+
Functors.@functor Foo
9+
Optimisers.trainable(x::Foo) = (x.y, x.x)
10+
11+
struct TwoThirds a; b; c; end
12+
Functors.@functor TwoThirds (a, c)
13+
Optimisers.trainable(x::TwoThirds) = (a = x.a,)
14+
715
@testset verbose=true "Optimisers.jl" begin
816

917
@testset "very basics" begin
@@ -23,7 +31,7 @@ using Optimisers: @..
2331
@test m3[1] ≈ [1,2] .- 0.1 .* [25, 33]
2432
end
2533

26-
@testset "$(first(string(o), 42))" for o in (
34+
@testset "rule: $(first(string(o), 42))" for o in (
2735
Descent(), ADAM(), Momentum(), Nesterov(), RMSProp(),
2836
ADAGrad(), AdaMax(), ADADelta(), AMSGrad(), NADAM(),
2937
ADAMW(), RADAM(), OADAM(), AdaBelief()
@@ -99,6 +107,38 @@ using Optimisers: @..
99107
@test isnan(m3n.γ[3])
100108
end
101109

110+
@testset "trainable subset" begin
111+
# Foo has an old-style tuple trainable, both elements
112+
mf = Foo([1,2], (a = sin, b = [3,4], c = 5))
113+
sf = Optimisers.setup(Descent(0.1), mf)
114+
gf = (x = nothing, y = (a = nothing, b = [1,1], c = 1))
115+
_, mf2 = Optimisers.update(sf, mf, gf)
116+
@test mf2.x == [1,2]
117+
@test mf2.y == (a = sin, b = [2.9, 3.9], c = 5)
118+
119+
# TwoThirds has functor a,c only, and trainable a only
120+
mt = TwoThirds(Float32[1,2], Float32[3,4], Float32[5,6])
121+
mt10 = fmap(x -> 10x, mt)
122+
@test mt10.a == [10, 20]
123+
@test mt10.b == [3, 4]
124+
@test mt10.c == [50, 60]
125+
st = Optimisers.setup(Momentum(0.1, 0.9), mt)
126+
gt = gradient(m -> sum(abs2, m.a) + 100sum(abs2, m.b), mt)
127+
_, mtup = Optimisers.update(st, mt, gt...)
128+
@test mtup.a ≈ [0.8, 1.6]
129+
@test mtup.b == [3, 4]
130+
@test mtup.c == [5, 6]
131+
132+
# Various kinds of missing branches together:
133+
m = Foo(
134+
TwoThirds(Foo(1.0, Float32[2,3,4]), 5.0, Float32[6,7]),
135+
TwoThirds((p = Float32[1,2,3],), sin, (q = 4.0, r = cos,)),
136+
)
137+
s = Optimisers.setup(Momentum(0.1, 0.9), m)
138+
g = gradient(m -> sum(abs2, m.x.a.y) + m.x.b^2 + log(m.y.c.q), m)
139+
@test Optimisers.update!(s, m, g...)[2] isa Foo
140+
end
141+
102142
@testset "broadcasting macro" begin
103143
x = [1.0, 2.0]; y = [3,4]; z = [5,6]
104144
@test (@.. x + y * z) isa Broadcast.Broadcasted

0 commit comments

Comments
 (0)