@@ -40,7 +40,7 @@ function _setup(rule, x; cache)
40
40
cache[x] = ℓ
41
41
end
42
42
else
43
- map (xᵢ -> _setup (rule, xᵢ; cache), _trainable (x))
43
+ valuemap (xᵢ -> _setup (rule, xᵢ; cache), _trainable (x))
44
44
end
45
45
end
46
46
@@ -77,7 +77,7 @@ function _update!(tree, x; grads, params)
77
77
haskey (params, (tree,x)) && return params[(tree,x)]
78
78
isbits (tree) && return x # means () is not cached, and also (((),),)
79
79
x′, re = functor (x)
80
- x′′ = re (map ((tᵢ, xᵢ) -> _update! (tᵢ, xᵢ; grads, params), tree, x′))
80
+ x′′ = re (valuemap ((tᵢ, xᵢ) -> _update! (tᵢ, xᵢ; grads, params), tree, x′))
81
81
if ismutable (x′′)
82
82
params[(tree,x)] = x′′
83
83
else # no ties to preserve between immutable structs, right?
@@ -109,7 +109,7 @@ function _grads!(dict::IdDict, tree, x, x̄s...)
109
109
# functor(typeof(tree), base(x̄)), for things like Transpose
110
110
x̄s′ = map (x̄ -> functor (typeof (x), base (x̄))[1 ], x̄s)
111
111
x′, _ = functor (typeof (x), x)
112
- foreach ((tᵢ, xᵢ, x̄sᵢ... ) -> _grads! (dict, tᵢ, xᵢ, x̄sᵢ... ), tree, x′, x̄s′... )
112
+ valueforeach ((tᵢ, xᵢ, x̄sᵢ... ) -> _grads! (dict, tᵢ, xᵢ, x̄sᵢ... ), tree, x′, x̄s′... )
113
113
end
114
114
115
115
# default all rules to first order calls
@@ -160,11 +160,22 @@ _trainable(x) = _trainable(functor(x)[1], trainable(x))
160
160
_trainable (ch:: NamedTuple , tr:: NamedTuple ) = merge (map (_ -> nothing , ch), tr)
161
161
_trainable (ch:: Tuple{Vararg{Any,N}} , tr:: Tuple{Vararg{Any,N}} ) where N = tr
162
162
_trainable (ch:: AbstractArray , tr:: AbstractArray ) = tr
163
+ _trainable (ch:: Dict , tr:: Dict ) = merge (valuemap (_ -> nothing , ch), tr)
164
+
163
165
function _trainable (ch:: NamedTuple , tr:: Tuple ) # for old Flux-style no-names tuple
164
166
@warn " trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog= 3
165
167
map (c -> c in tr ? c : nothing , ch)
166
168
end
167
169
170
+
171
+ valuemap (f, x... ) = map (f, x... )
172
+ valuemap (f, x:: Dict , ys... ) = Dict (k => f (v, (get (y, k, nothing ) for y in ys). .. ) for (k,v) in x)
173
+ valueforeach (f, x... ) = foreach (f, x... )
174
+ valueforeach (f, x:: Dict , ys... ) = foreach (pairs (x)) do (k, v)
175
+ f (v, (get (y, k, nothing ) for y in ys). .. )
176
+ end
177
+
178
+
168
179
# ##
169
180
# ## rule definition helpers
170
181
# ##
0 commit comments