Skip to content

Commit 997a84d

Browse files
mcabbott, oxinabox
authored and committed
tidy up using _maybe_projector, _always_projector
1 parent 008f4ef commit 997a84d

File tree

2 files changed

+38
-48
lines changed

2 files changed

+38
-48
lines changed

src/projection.jl

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,7 @@ end
4242
# used on unknown structs, but useful for handling many known ones in the same manner.
4343
function generic_projector(x::T; kw...) where {T}
4444
fields_nt::NamedTuple = backing(x)
45-
fields_proj = map(fields_nt) do x1
46-
if x1 isa Number || x1 isa AbstractArray
47-
ProjectTo(x1)
48-
else
49-
# This stores things like Symbols & functions verbatim,
50-
# and _maybe_project below keeps these in the reconstructed result.
51-
x1
52-
end
53-
end
45+
fields_proj = map(_maybe_projector, fields_nt)
5446
# We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
5547
# `Foo{Diagonal{E}}` etc. We assume it has a default constructor that has all fields
5648
# but if it doesn't `construct` will give a good error message.
@@ -61,17 +53,25 @@ end
6153
function generic_projection(project::ProjectTo{T}, dx::T) where {T}
6254
sub_projects = backing(project)
6355
sub_dxs = backing(dx)
64-
return construct(T, map(_maybe_project, sub_projects, sub_dxs))
56+
return construct(T, map(_maybe_call, sub_projects, sub_dxs))
6557
end
6658

6759
function (project::ProjectTo{T})(dx::Tangent) where {T}
6860
sub_projects = backing(project)
6961
sub_dxs = backing(canonicalize(dx))
70-
return construct(T, map(_maybe_project, sub_projects, sub_dxs))
62+
return construct(T, map(_maybe_call, sub_projects, sub_dxs))
7163
end
7264

73-
_maybe_project(f::ProjectTo, x) = f(x)
74-
_maybe_project(f, x) = f
65+
# Used for encoding fields, leaves alone non-diff types:
66+
_maybe_projector(x::Union{AbstractArray, Number, Ref}) = ProjectTo(x)
67+
_maybe_projector(x) = x
68+
# Used for re-constructing fields, restores non-diff types:
69+
_maybe_call(f::ProjectTo, x) = f(x)
70+
_maybe_call(f, x) = f
71+
72+
# Used for elements of e.g. Array{Any}, trivial projector
73+
_always_projector(x::Union{AbstractArray, Number, Ref}) = ProjectTo(x)
74+
_always_projector(x) = ProjectTo()
7575

7676
"""
7777
ProjectTo(x)
@@ -151,50 +151,48 @@ ProjectTo(x::Complex{<:Integer}) = ProjectTo(float(x))
151151
# Arrays
152152
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
153153
# no structure worth re-imposing. Then any array is acceptable as a gradient.
154+
155+
# For arrays of numbers, just store one projector:
154156
function ProjectTo(x::AbstractArray{T}) where {T<:Number}
155-
# For arrays of numbers, just store one projector:
156157
element = ProjectTo(zero(T))
157-
# If all our elements are going to zero, then we can short circuit and just send the whole thing
158-
element isa ProjectTo{<:AbstractZero} && return element
159-
return ProjectTo{AbstractArray}(; element=element, axes=axes(x))
158+
if element isa ProjectTo{<:AbstractZero}
159+
return ProjectTo{NoTangent}() # short-circuit if all elements project to zero
160+
else
161+
return ProjectTo{AbstractArray}(; element=element, axes=axes(x))
162+
end
160163
end
164+
165+
# In other cases, store a projector per element:
161166
function ProjectTo(xs::AbstractArray)
162-
# Especially for arrays of arrays, we will store a projector per element:
163-
elements = map(xs) do x
164-
if x isa Number || x isa AbstractArray
165-
ProjectTo(x)
166-
else
167-
ProjectTo()
168-
end
169-
end
167+
elements = map(_always_projector, xs)
170168
if elements isa AbstractArray{<:ProjectTo{<:AbstractZero}}
171-
return ProjectTo{NoTangent}()
169+
return ProjectTo{NoTangent}() # short-circuit if all elements project to zero
172170
elseif elements isa AbstractArray{<:ProjectTo{Any}}
173-
return ProjectTo{AbstractArray}(; element=ProjectTo(), axes=axes(xs))
171+
return ProjectTo{AbstractArray}(; element=ProjectTo(), axes=axes(xs)) # ... or none project
174172
else
175-
# They will be individually applied:
173+
# Arrays of arrays come here, and will apply projectors individually:
176174
return ProjectTo{AbstractArray}(; elements=elements, axes=axes(xs))
177175
end
178176
end
177+
179178
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
179+
# First deal with shape. The rule is that we reshape to add or remove trivial dimensions
180+
# like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
180181
dy = if axes(dx) == project.axes
181182
dx
182183
else
183-
# The rule here is that we reshape to add or remove trivial dimensions like dx = ones(4,1),
184-
# where x = ones(4), but throw an error on dx = ones(1,4) etc.
185184
for d in 1:max(M, length(project.axes))
186185
size(dx, d) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(project.axes, size(dx)))
187186
end
188187
reshape(dx, project.axes)
189188
end
189+
# Then deal with the elements. One projector if AbstractArray{<:Number},
190+
# or one per element for arrays of arrays:
190191
dz = if hasfield(typeof(backing(project)), :element)
191-
# Easy case, like AbstractArray{<:Number}, fix eltype if necessary
192192
T = project_type(project.element)
193193
S <: T ? dy : map(project.element, dy)
194-
elseif hasfield(typeof(backing(project)), :elements)
195-
map((f,y) -> f(y), project.elements, dy)
196194
else
197-
throw(ArgumentError("bad ProjectTo{AbstractArray} -- it should always have .element or .elements"))
195+
map((f,y) -> f(y), project.elements, dy)
198196
end
199197
return dz
200198
end
@@ -206,20 +204,9 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
206204
return fill(project.element(dx))
207205
end
208206

209-
# Ref -- works like a zero-array:
210-
ProjectTo(x::Ref{<:Number}) = ProjectTo{Ref}(; x = ProjectTo(getindex(x)))
211-
ProjectTo(x::Ref{<:AbstractArray}) = ProjectTo{Ref}(; x = ProjectTo(getindex(x)))
212-
function ProjectTo(x::Ref)
213-
return if !isassigned(x)
214-
ProjectTo{Ref}(; x = ProjectTo())
215-
elseif x[] isa Number || x[] isa AbstractArray
216-
ProjectTo{Ref}(; x = ProjectTo(x[]))
217-
else
218-
ProjectTo{Ref}(; x = ProjectTo())
219-
end
220-
end
207+
# Ref -- works like a zero-array, allows restoration from a number:
208+
ProjectTo(x::Ref) = ProjectTo{Ref}(; x = _always_projector(x[]))
221209
(project::ProjectTo{Ref})(dx::Ref) = Ref(project.x(dx[]))
222-
# And like zero-dim arrays, allow restoration from a number:
223210
(project::ProjectTo{Ref})(dx::Number) = Ref(project.x(dx))
224211

225212
function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
@@ -238,7 +225,7 @@ function ProjectTo(x::LinearAlgebra.AdjointAbsVec{T}) where {T<:Number}
238225
end
239226
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
240227
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
241-
# but while row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
228+
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
242229
(project::ProjectTo{Adjoint})(dx::Adjoint) = adjoint(project.parent(parent(dx)))
243230
(project::ProjectTo{Adjoint})(dx::Transpose) = adjoint(conj(project.parent(parent(dx)))) # might copy twice?
244231
function (project::ProjectTo{Adjoint})(dx::AbstractArray)

test/projection.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ using OffsetArrays, BenchmarkTools
3434
@test pmat([1 2; 3 4.0 + 5im]') isa Adjoint # pass-through
3535
@test pmat([1 2; 3 4]') isa Matrix{ComplexF64} # broadcast type change
3636

37+
pmat2 = ProjectTo(rand(2,2)')
38+
@test pmat2([1 2; 3 4.0 + 5im]) isa Matrix # adjoint matrices are not preserved
39+
3740
# arrays of arrays
3841
pvecvec = ProjectTo([[1,2], [3,4,5]])
3942
@test pvecvec([1:2, 3:5])[1] == 1:2

0 commit comments

Comments
 (0)