42
42
# used on unknown structs, but useful for handling many known ones in the same manner.
43
43
function generic_projector (x:: T ; kw... ) where {T}
44
44
fields_nt:: NamedTuple = backing (x)
45
- fields_proj = map (fields_nt) do x1
46
- if x1 isa Number || x1 isa AbstractArray
47
- ProjectTo (x1)
48
- else
49
- # This stores things like Symbols & functions verbatim,
50
- # and _maybe_project below keeps these in the reconstructed result.
51
- x1
52
- end
53
- end
45
+ fields_proj = map (_maybe_projector, fields_nt)
54
46
# We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
55
47
# `Foo{Diagonal{E}}` etc. We assume it has a default constructor that has all fields
56
48
# but if it doesn't `construct` will give a good error message.
61
53
function generic_projection (project:: ProjectTo{T} , dx:: T ) where {T}
62
54
sub_projects = backing (project)
63
55
sub_dxs = backing (dx)
64
- return construct (T, map (_maybe_project , sub_projects, sub_dxs))
56
+ return construct (T, map (_maybe_call , sub_projects, sub_dxs))
65
57
end
66
58
67
59
function (project:: ProjectTo{T} )(dx:: Tangent ) where {T}
68
60
sub_projects = backing (project)
69
61
sub_dxs = backing (canonicalize (dx))
70
- return construct (T, map (_maybe_project , sub_projects, sub_dxs))
62
+ return construct (T, map (_maybe_call , sub_projects, sub_dxs))
71
63
end
72
64
73
- _maybe_project (f:: ProjectTo , x) = f (x)
74
- _maybe_project (f, x) = f
65
+ # Used for encoding fields, leaves alone non-diff types:
66
+ _maybe_projector (x:: Union{AbstractArray, Number, Ref} ) = ProjectTo (x)
67
+ _maybe_projector (x) = x
68
+ # Used for re-constructing fields, restores non-diff types:
69
+ _maybe_call (f:: ProjectTo , x) = f (x)
70
+ _maybe_call (f, x) = f
71
+
72
+ # Used for elements of e.g. Array{Any}, trivial projector
73
+ _always_projector (x:: Union{AbstractArray, Number, Ref} ) = ProjectTo (x)
74
+ _always_projector (x) = ProjectTo ()
75
75
76
76
"""
77
77
ProjectTo(x)
@@ -151,50 +151,48 @@ ProjectTo(x::Complex{<:Integer}) = ProjectTo(float(x))
151
151
# Arrays
152
152
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
153
153
# no structure worth re-imposing. Then any array is acceptable as a gradient.
154
+
155
+ # For arrays of numbers, just store one projector:
154
156
function ProjectTo (x:: AbstractArray{T} ) where {T<: Number }
155
- # For arrays of numbers, just store one projector:
156
157
element = ProjectTo (zero (T))
157
- # If all our elements are going to zero, then we can short circuit and just send the whole thing
158
- element isa ProjectTo{<: AbstractZero } && return element
159
- return ProjectTo {AbstractArray} (; element= element, axes= axes (x))
158
+ if element isa ProjectTo{<: AbstractZero }
159
+ return ProjectTo {NoTangent} () # short-circuit if all elements project to zero
160
+ else
161
+ return ProjectTo {AbstractArray} (; element= element, axes= axes (x))
162
+ end
160
163
end
164
+
165
+ # In other cases, store a projector per element:
161
166
function ProjectTo (xs:: AbstractArray )
162
- # Especially for arrays of arrays, we will store a projector per element:
163
- elements = map (xs) do x
164
- if x isa Number || x isa AbstractArray
165
- ProjectTo (x)
166
- else
167
- ProjectTo ()
168
- end
169
- end
167
+ elements = map (_always_projector, xs)
170
168
if elements isa AbstractArray{<: ProjectTo{<:AbstractZero} }
171
- return ProjectTo {NoTangent} ()
169
+ return ProjectTo {NoTangent} () # short-circuit if all elements project to zero
172
170
elseif elements isa AbstractArray{<: ProjectTo{Any} }
173
- return ProjectTo {AbstractArray} (; element= ProjectTo (), axes= axes (xs))
171
+ return ProjectTo {AbstractArray} (; element= ProjectTo (), axes= axes (xs)) # ... or none project
174
172
else
175
- # They will be individually applied :
173
+ # Arrays of arrays come here, and will apply projectors individually :
176
174
return ProjectTo {AbstractArray} (; elements= elements, axes= axes (xs))
177
175
end
178
176
end
177
+
179
178
function (project:: ProjectTo{AbstractArray} )(dx:: AbstractArray{S,M} ) where {S,M}
179
+ # First deal with shape. The rule is that we reshape to add or remove trivial dimensions
180
+ # like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
180
181
dy = if axes (dx) == project. axes
181
182
dx
182
183
else
183
- # The rule here is that we reshape to add or remove trivial dimensions like dx = ones(4,1),
184
- # where x = ones(4), but throw an error on dx = ones(1,4) etc.
185
184
for d in 1 : max (M, length (project. axes))
186
185
size (dx, d) == length (get (project. axes, d, 1 )) || throw (_projection_mismatch (project. axes, size (dx)))
187
186
end
188
187
reshape (dx, project. axes)
189
188
end
189
+ # Then deal with the elements. One projector if AbstractArray{<:Number},
190
+ # or one per element for arrays of arrays:
190
191
dz = if hasfield (typeof (backing (project)), :element )
191
- # Easy case, like AbstractArray{<:Number}, fix eltype if necessary
192
192
T = project_type (project. element)
193
193
S <: T ? dy : map (project. element, dy)
194
- elseif hasfield (typeof (backing (project)), :elements )
195
- map ((f,y) -> f (y), project. elements, dy)
196
194
else
197
- throw ( ArgumentError ( " bad ProjectTo{AbstractArray} -- it should always have .element or . elements" ) )
195
+ map ((f,y) -> f (y), project . elements, dy )
198
196
end
199
197
return dz
200
198
end
@@ -206,20 +204,9 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
206
204
return fill (project. element (dx))
207
205
end
208
206
209
- # Ref -- works like a zero-array:
210
- ProjectTo (x:: Ref{<:Number} ) = ProjectTo {Ref} (; x = ProjectTo (getindex (x)))
211
- ProjectTo (x:: Ref{<:AbstractArray} ) = ProjectTo {Ref} (; x = ProjectTo (getindex (x)))
212
- function ProjectTo (x:: Ref )
213
- return if ! isassigned (x)
214
- ProjectTo {Ref} (; x = ProjectTo ())
215
- elseif x[] isa Number || x[] isa AbstractArray
216
- ProjectTo {Ref} (; x = ProjectTo (x[]))
217
- else
218
- ProjectTo {Ref} (; x = ProjectTo ())
219
- end
220
- end
207
+ # Ref -- works like a zero-array, allows restoration from a number:
208
+ ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x = _always_projector (x[]))
221
209
(project:: ProjectTo{Ref} )(dx:: Ref ) = Ref (project. x (dx[]))
222
- # And like zero-dim arrays, allow restoration from a number:
223
210
(project:: ProjectTo{Ref} )(dx:: Number ) = Ref (project. x (dx))
224
211
225
212
function _projection_mismatch (axes_x:: Tuple , size_dx:: Tuple )
@@ -238,7 +225,7 @@ function ProjectTo(x::LinearAlgebra.AdjointAbsVec{T}) where {T<:Number}
238
225
end
239
226
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
240
227
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
241
- # but while row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
228
+ # but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
242
229
(project:: ProjectTo{Adjoint} )(dx:: Adjoint ) = adjoint (project. parent (parent (dx)))
243
230
(project:: ProjectTo{Adjoint} )(dx:: Transpose ) = adjoint (conj (project. parent (parent (dx)))) # might copy twice?
244
231
function (project:: ProjectTo{Adjoint} )(dx:: AbstractArray )
0 commit comments