Skip to content

Commit c08fc58

Browse files
add support for dicts (#122)
* add support for dicts * aaa * fix update * cleanup * more tests * Update src/interface.jl Co-authored-by: Brian Chen <[email protected]> Co-authored-by: Brian Chen <[email protected]>
1 parent 79269be commit c08fc58

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

src/backup.jl

Whitespace-only changes.

src/interface.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ function _setup(rule, x; cache)
4040
cache[x] =
4141
end
4242
else
43-
map(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
43+
valuemap(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
4444
end
4545
end
4646

@@ -77,7 +77,7 @@ function _update!(tree, x; grads, params)
7777
haskey(params, (tree,x)) && return params[(tree,x)]
7878
isbits(tree) && return x # means () is not cached, and also (((),),)
7979
x′, re = functor(x)
80-
x′′ = re(map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
80+
x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
8181
if ismutable(x′′)
8282
params[(tree,x)] = x′′
8383
else # no ties to preserve between immutable structs, right?
@@ -109,7 +109,7 @@ function _grads!(dict::IdDict, tree, x, x̄s...)
109109
# functor(typeof(tree), base(x̄)), for things like Transpose
110110
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
111111
x′, _ = functor(typeof(x), x)
112-
foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
112+
valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
113113
end
114114

115115
# default all rules to first order calls
@@ -160,11 +160,22 @@ _trainable(x) = _trainable(functor(x)[1], trainable(x))
160160
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
161161
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
162162
_trainable(ch::AbstractArray, tr::AbstractArray) = tr
163+
_trainable(ch::Dict, tr::Dict) = merge(valuemap(_ -> nothing, ch), tr)
164+
163165
function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple
164166
@warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3
165167
map(c -> c in tr ? c : nothing, ch)
166168
end
167169

170+
171+
valuemap(f, x...) = map(f, x...)
172+
valuemap(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)
173+
valueforeach(f, x...) = foreach(f, x...)
174+
valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
175+
f(v, (get(y, k, nothing) for y in ys)...)
176+
end
177+
178+
168179
###
169180
### rule definition helpers
170181
###

test/runtests.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,36 @@ y2z(x) = x
9999
@test isnan(m3n.γ[3])
100100
end
101101

102+
@testset "Dict support" begin
103+
@testset "simple dict" begin
104+
d = Dict(:a => [1.0,2.0], :b => [3.0,4.0], :c => 1)
105+
s = Optimisers.setup(AdamW(0.1), d)
106+
@test s isa Dict{Symbol, <:Any}
107+
@test s[:a] isa Optimisers.Leaf
108+
@test s[:b] isa Optimisers.Leaf
109+
@test s[:c] === ()
110+
loss(model) = sum(abs2, model[:a])
111+
g = gradient(loss, d)[1]
112+
s2, d2 = Optimisers.update(s, d, g)
113+
@test s2 isa Dict{Symbol, <:Any}
114+
@test d2 isa Dict{Symbol, <:Any}
115+
@test d2[:a] == [0.9, 1.9]
116+
@test d2[:b] == [3, 4]
117+
@test d2[:c] == 1
118+
end
119+
120+
@testset "nested dict" begin
121+
d = Dict(1 => [1.0,2.0], 2 => Dict("a" => (; c=[3.0,4.0]), "b" => 1))
122+
s = Optimisers.setup(AdamW(0.1), d)
123+
@test s[2]["a"].c isa Optimisers.Leaf
124+
g = gradient(d -> sum(d[2]["a"].c), d)[1]
125+
s2, d2 = Optimisers.update(s, d, g)
126+
@test d2[2]["a"].c == [2.9, 3.9]
127+
@test d2[1] == [1, 2]
128+
@test d2[2]["b"] == 1
129+
end
130+
end
131+
102132
@testset "OptimiserChain" begin
103133
x = [1, 10, 100.0]; dx = [1, 2, 3.0];
104134
@test Optimisers.update(Optimisers.setup(WeightDecay(0.1), x), x, dx)[2] [1-0.1-1, 10-1-2, 100-10-3]

0 commit comments

Comments
 (0)