@@ -66,7 +66,7 @@ function _flatten(x)
   isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
   arrays = AbstractVector[]
   len = Ref(0)
-  off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
+  off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
     push!(arrays, _vec(y))
     o = len[]
     len[] = o + length(y)
@@ -76,18 +76,22 @@ function _flatten(x)
   reduce(vcat, arrays), off, len[]
 end
 
+struct _TrainableStructWalk <: AbstractWalk end
+
+(::_TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
+
 _vec(x::Number) = LinRange(x,x,1)
 _vec(x::AbstractArray) = vec(x)
 
 function ChainRulesCore.rrule(::typeof(_flatten), x)
   flat, off, len = _flatten(x)
   _maybewarn()
-  _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT))
+  _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk(), prune = NoT))
   (flat, off, len), _flatten_back
 end
 
 # This reconstructs either a model like x, or a gradient for it:
-function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trainable_biwalk, kw...)
+function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
   len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
   fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
     _getat(y, o, flat)
@@ -98,7 +102,9 @@ _getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
 _getat(y::AbstractArray, o::Int, flat::AbstractVector) =
   ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))  # ProjectTo is just correcting eltypes
 
-function _trainable_biwalk(f, x, aux)
+struct _Trainable_biwalk <: AbstractWalk end
+
+function (::_Trainable_biwalk)(f, x, aux)
   ch, re = functor(typeof(x), x)
   au, _ = functor(typeof(x), aux)
   _trainmap(f, ch, _trainable(x), au) |> re
@@ -110,7 +116,9 @@ function _trainmap(f, ch, tr, aux)
   end
 end
 
-function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
+struct _Tangent_biwalk <: AbstractWalk end
+
+function (::_Tangent_biwalk)(f, x, aux)  # use with prune = NoT
   ch, re = functor(typeof(x), x)
   au, _ = functor(typeof(x), aux)
   y = _trainmap(f, ch, _trainable(x), au)
@@ -156,7 +164,7 @@ _grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat  # ambiguity
 # These are only needed for 2nd derivatives:
 function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
   @warn "second derivatives of Restructure may not work yet, sorry!" maxlog=3
-  _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
+  _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk(), prune = NoT), NoT, NoT)
   _grad!(x, dx, off, flat), _grad_back
 end
 base(dx::Tangent{<:Tangent}) = backing(dx).backing  # might be needed for gradient(gradient(destructure))
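
For context: this commit moves the walks from anonymous functions to callable structs subtyping Functors.AbstractWalk, the interface that Functors.jl 0.3 expects for the `walk` keyword of `fmap`. A minimal self-contained sketch of that pattern, assuming Functors.jl >= 0.3; `AllChildrenWalk` and the toy model are invented for illustration, while the real `_TrainableStructWalk` recurses only into `_trainable(x)`:

using Functors

# A walk receives `recurse` (which already carries the mapped function
# and `exclude`) plus the current node, and decides which children to visit.
struct AllChildrenWalk <: Functors.AbstractWalk end
(::AllChildrenWalk)(recurse, x) = map(recurse, Functors.children(x))

len = Ref(0)
offs = fmap((a = [1.0, 2.0], b = (c = [3.0],)); walk = AllChildrenWalk()) do y
  o = len[]
  len[] = o + length(y)
  o  # record each leaf's offset, as _flatten does above
end
# offs == (a = 0, b = (c = 2,)) and len[] == 3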
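These internals back the exported API: `destructure` calls `_flatten`, and the returned `Restructure` calls `_rebuild`. A quick round trip using only the public interface, with expected results as comments:

using Optimisers

m = (x = [1.0, 2.0], y = (sin, [3.0, 4.0]))
flat, re = destructure(m)  # flat == [1.0, 2.0, 3.0, 4.0]; sin is not numeric, so it is skipped
re([10, 20, 30, 40])       # == (x = [10.0, 20.0], y = (sin, [30.0, 40.0])), eltypes corrected by ProjectTo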
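And the rrules above make the flattening differentiable, with `_Tangent_biwalk()` plus `prune = NoT` dropping non-trainable fields from the tangent. A hedged sketch of that path, assuming Zygote is on hand (the exact tangent representation may differ):

using Zygote, Optimisers

m = (x = [1.0, 2.0], y = (sin, [3.0, 4.0]))
g = Zygote.gradient(m -> sum(abs2, destructure(m)[1]), m)[1]
# expected: g.x == [2.0, 4.0], g.y[2] == [6.0, 8.0], and the
# non-trainable sin field comes back as nothing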