@@ -59,10 +59,8 @@ function generic_projector(x::T; kw...) where {T}
5959 fields_nt:: NamedTuple = backing (x)
6060 fields_proj = map (_maybe_projector, fields_nt)
6161 # We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
62- # `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
63- # but if it doesn't `construct` will give a good error message.
62+ # `Foo{Diagonal{E}}` etc. Official API for this? https://github.com/JuliaLang/julia/issues/35543
6463 wrapT = T. name. wrapper
65- # Official API for this? https://github.com/JuliaLang/julia/issues/35543
6664 return ProjectTo {wrapT} (; fields_proj... , kw... )
6765end
6866
@@ -72,12 +70,6 @@ function generic_projection(project::ProjectTo{T}, dx::T) where {T}
7270 return construct (T, map (_maybe_call, sub_projects, sub_dxs))
7371end
7472
75- function (project:: ProjectTo{T} )(dx:: Tangent ) where {T}
76- sub_projects = backing (project)
77- sub_dxs = backing (canonicalize (dx))
78- return construct (T, map (_maybe_call, sub_projects, sub_dxs))
79- end
80-
8173# Used for encoding fields, leaves alone non-diff types:
8274_maybe_projector (x:: Union{AbstractArray,Number,Ref} ) = ProjectTo (x)
8375_maybe_projector (x) = x
@@ -123,7 +115,6 @@ ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2),
123115ProjectTo (:: Any ) # just to attach docstring
124116
125117# Generic
126- (:: ProjectTo{T} )(dx:: T ) where {T} = dx # not always correct but we have special cases for when it isn't
127118(:: ProjectTo{T} )(dx:: AbstractZero ) where {T} = dx
128119(:: ProjectTo{T} )(dx:: NotImplemented ) where {T} = dx
129120
@@ -133,7 +124,17 @@ ProjectTo(::Any) # just to attach docstring
133124# Zero
134125ProjectTo (:: AbstractZero ) = ProjectTo {NoTangent} () # Any x::Zero in forward pass makes this one projector,
135126(:: ProjectTo{NoTangent} )(dx) = NoTangent () # but this is the projection only for nonzero gradients,
136- (:: ProjectTo{NoTangent} )(:: NoTangent ) = NoTangent () # and this one solves an ambiguity.
127+ (:: ProjectTo{NoTangent} )(dx:: AbstractZero ) = dx # and this one solves an ambiguity.
128+
129+ # Also, any explicit construction with fields, where all fields project to zero, itself
130+ # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
131+ const _PZ = ProjectTo{<: AbstractZero }
132+ ProjectTo {P} (:: NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}} ) where {P,T} = ProjectTo {NoTangent} ()
133+
134+ # Tangent
135+ # We haven't entirely figured out when to convert Tangents to "natural" representations such as
136+ # dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
137+ (:: ProjectTo{T} )(dx:: Tangent{<:T} ) where {T} = dx
137138
138139# ####
139140# #### `Base`
@@ -165,27 +166,29 @@ end
165166(:: ProjectTo{T} )(dx:: Integer ) where {T<: Complex{<:AbstractFloat} } = convert (T, dx)
166167
167168# Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through.
168- # We assume (lacking evidence to the contrary) that it is the right subspace of numebers
169- # The (::ProjectTo{T})(::T) method doesn't work because we are allowing a different
170- # Number type that might not be a subtype of the `project_type`.
169+ # We assume (lacking evidence to the contrary) that it is the right subspace of numbers.
171170(:: ProjectTo{<:Number} )(dx:: Number ) = dx
172171
173172(project:: ProjectTo{<:Real} )(dx:: Complex ) = project (real (dx))
174173(project:: ProjectTo{<:Complex} )(dx:: Real ) = project (complex (dx))
175174
175+ # Tangents: we prefer to reconstruct numbers, but this is only safe when their constructor
176+ # understands the fields, including a mix of Zeros & reals. In other cases, we just let them through:
177+ (project:: ProjectTo{<:Complex} )(dx:: Tangent{<:Complex} ) = project (Complex (dx. re, dx. im))
178+ (:: ProjectTo{<:Number} )(dx:: Tangent{<:Number} ) = dx
179+
176180# Arrays
177181# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
178182# no structure worth re-imposing. Then any array is acceptable as a gradient.
179183
180184# For arrays of numbers, just store one projector:
181185function ProjectTo (x:: AbstractArray{T} ) where {T<: Number }
182- element = T <: Irrational ? ProjectTo {Real} () : ProjectTo (zero (T))
183- if element isa ProjectTo{<: AbstractZero }
184- return ProjectTo {NoTangent} () # short-circuit if all elements project to zero
185- else
186- return ProjectTo {AbstractArray} (; element= element, axes= axes (x))
187- end
186+ return ProjectTo {AbstractArray} (; element= _eltype_projectto (T), axes= axes (x))
188187end
188+ ProjectTo (x:: AbstractArray{Bool} ) = ProjectTo {NoTangent} ()
189+
190+ _eltype_projectto (:: Type{T} ) where {T<: Number } = ProjectTo (zero (T))
191+ _eltype_projectto (:: Type{<:Irrational} ) = ProjectTo {Real} ()
189192
190193# In other cases, store a projector per element:
191194function ProjectTo (xs:: AbstractArray )
@@ -241,27 +244,39 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
241244 return fill (project. element (dx))
242245end
243246
244- # Ref -- works like a zero-array, also allows restoration from a number:
245- ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x= ProjectTo (x[]))
246- (project:: ProjectTo{Ref} )(dx:: Ref ) = Ref (project. x (dx[]))
247- (project:: ProjectTo{Ref} )(dx:: Number ) = Ref (project. x (dx))
248-
249247function _projection_mismatch (axes_x:: Tuple , size_dx:: Tuple )
250248 size_x = map (length, axes_x)
251249 return DimensionMismatch (
252250 " variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx "
253251 )
254252end
255253
254+ # ####
255+ # #### `Base`, part II: return of the Tangent
256+ # ####
257+
258+ # Ref
259+ function ProjectTo (x:: Ref )
260+ sub = ProjectTo (x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
261+ if sub isa ProjectTo{<: AbstractZero }
262+ return ProjectTo {NoTangent} ()
263+ else
264+ return ProjectTo {Ref} (; type= typeof (x), x= sub)
265+ end
266+ end
267+ (project:: ProjectTo{Ref} )(dx:: Tangent{<:Ref} ) = Tangent {project.type} (; x= project. x (dx. x))
268+ (project:: ProjectTo{Ref} )(dx:: Ref ) = Tangent {project.type} (; x= project. x (dx[]))
269+ # Since this works like a zero-array in broadcasting, it should also accept a number:
270+ (project:: ProjectTo{Ref} )(dx:: Number ) = Tangent {project.type} (; x= project. x (dx))
271+
256272# ####
257273# #### `LinearAlgebra`
258274# ####
259275
276+ using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
277+
260278# Row vectors
261- function ProjectTo (x:: LinearAlgebra.AdjointAbsVec )
262- sub = ProjectTo (parent (x))
263- return ProjectTo {Adjoint} (; parent= sub)
264- end
279+ ProjectTo (x:: AdjointAbsVec ) = ProjectTo {Adjoint} (; parent= ProjectTo (parent (x)))
265280# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
266281# Transposed matrices are, like PermutedDimsArray, just a storage detail,
267282# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
@@ -276,10 +291,7 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
276291 return adjoint (project. parent (dy))
277292end
278293
279- function ProjectTo (x:: LinearAlgebra.TransposeAbsVec )
280- sub = ProjectTo (parent (x))
281- return ProjectTo {Transpose} (; parent= sub)
282- end
294+ ProjectTo (x:: LinearAlgebra.TransposeAbsVec ) = ProjectTo {Transpose} (; parent= ProjectTo (parent (x)))
283295function (project:: ProjectTo{Transpose} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
284296 return transpose (project. parent (transpose (dx)))
285297end
@@ -292,11 +304,7 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
292304end
293305
294306# Diagonal
295- function ProjectTo (x:: Diagonal )
296- sub = ProjectTo (x. diag)
297- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if Diagonal(NoTangent()) worked
298- return ProjectTo {Diagonal} (; diag= sub)
299- end
307+ ProjectTo (x:: Diagonal ) = ProjectTo {Diagonal} (; diag= ProjectTo (x. diag))
300308(project:: ProjectTo{Diagonal} )(dx:: AbstractMatrix ) = Diagonal (project. diag (diag (dx)))
301309(project:: ProjectTo{Diagonal} )(dx:: Diagonal ) = Diagonal (project. diag (dx. diag))
302310
@@ -308,7 +316,8 @@ for (SymHerm, chk, fun) in (
308316 @eval begin
309317 function ProjectTo (x:: $SymHerm )
310318 sub = ProjectTo (parent (x))
311- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if Hermitian(NoTangent()) etc. worked
319+ # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial:
320+ sub isa ProjectTo{<: AbstractZero } && return sub
312321 return ProjectTo {$SymHerm} (; uplo= LinearAlgebra. sym_uplo (x. uplo), parent= sub)
313322 end
314323 function (project:: ProjectTo{$SymHerm} )(dx:: AbstractArray )
333342# Triangular
334343for UL in (:UpperTriangular , :LowerTriangular , :UnitUpperTriangular , :UnitLowerTriangular ) # UpperHessenberg
335344 @eval begin
336- function ProjectTo (x:: $UL )
337- sub = ProjectTo (parent (x))
338- # TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
339- sub isa ProjectTo{<: AbstractZero } && return sub
340- return ProjectTo {$UL} (; parent= sub)
341- end
345+ ProjectTo (x:: $UL ) = ProjectTo {$UL} (; parent= ProjectTo (parent (x)))
342346 (project:: ProjectTo{$UL} )(dx:: AbstractArray ) = $ UL (project. parent (dx))
343347 function (project:: ProjectTo{$UL} )(dx:: Diagonal )
344348 sub = project. parent
0 commit comments