@@ -40,6 +40,17 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
4040backing (project:: ProjectTo ) = getfield (project, :info )
4141
4242project_type (p:: ProjectTo{T} ) where {T} = T
43+ project_eltype (p:: ProjectTo{T} ) where {T} = eltype (T)
44+
45+ function project_promote_type (projectors)
46+ T = mapreduce (project_type, promote_type, projectors)
47+ if T <: Number
48+ # The point of this function is to make p.element for arrays. Not in use yet!
49+ return ProjectTo (zero (T))
50+ else
51+ return ProjectTo {Any} ()
52+ end
53+ end
4354
4455function Base. show (io:: IO , project:: ProjectTo{T} ) where {T}
4556 print (io, " ProjectTo{" )
181192# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
182193# no structure worth re-imposing. Then any array is acceptable as a gradient.
183194
184- # For arrays of numbers, just store one projector:
185- function ProjectTo (x:: AbstractArray{T} ) where {T<: Number }
186- return ProjectTo {AbstractArray} (; element= _eltype_projectto (T), axes= axes (x))
195+ # For arrays of numbers, just store one projector, and construct it without branches:
196+ ProjectTo (x:: AbstractArray{<:Number} ) = _array_projectto (x, axes (x))
197+ function _array_projectto (x:: AbstractArray{T,N} , axes:: NTuple{N,<:Base.OneTo{Int}} ) where {T,N}
198+ element = _eltype_projectto (T)
199+ S = project_type (element)
200+ # Fastest path: N means they are OneTo, hence reshape can be skipped
201+ return ProjectTo {AbstractArray{S,N}} (; element= element, axes= axes)
202+ end
203+ function _array_projectto (x:: AbstractArray{T,N} , axes:: Tuple ) where {T,N}
204+ element = _eltype_projectto (T)
205+ S = project_type (element)
206+ # Omitting N means reshape will be called, for OffsetArrays, SArrays, etc.
207+ return ProjectTo {AbstractArray{S}} (; element= element, axes= axes)
187208end
188209ProjectTo (x:: AbstractArray{Bool} ) = ProjectTo {NoTangent} ()
189210
@@ -201,7 +222,7 @@ function ProjectTo(xs::AbstractArray)
201222 end
202223end
203224
204- function (project:: ProjectTo{AbstractArray} )(dx:: AbstractArray{S,M} ) where {S,M}
225+ function (project:: ProjectTo{<: AbstractArray} )(dx:: AbstractArray{S,M} ) where {S,M}
205226 # First deal with shape. The rule is that we reshape to add or remove trivial dimensions
206227 # like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
207228 dy = if axes (dx) == project. axes
@@ -225,24 +246,34 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
225246 return dz
226247end
227248
249+ # Fast paths, for arrays of numbers:
250+ (:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{S,N} ) where {S<: T } where {T,N} = dx
251+ (project:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{S} ) where {S<: T } where {T,N} = reshape (dx, project. axes)
252+ (project:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{S,N} ) where {S,T,N} = map (project. element, dx)
253+ (project:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray ) where {T,N} = map (project. element, reshape (dx, project. axes))
254+
228255# Trivial case, this won't collapse Any[NoTangent(), NoTangent()] but that's OK.
229- (project:: ProjectTo{AbstractArray} )(dx:: AbstractArray{<:AbstractZero} ) = NoTangent ()
256+ (project:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{<:AbstractZero} ) where {T,N} = NoTangent ()
230257
231258# Row vectors aren't acceptable as gradients for 1-row matrices:
232- function (project:: ProjectTo{AbstractArray} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
259+ # function (project::ProjectTo{<:AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
260+ # return project(reshape(vec(dx), 1, :))
261+ # end
262+ function (project:: ProjectTo{AbstractArray{T,N}} )(dx:: LinearAlgebra.AdjOrTransAbsVec ) where {T,N}
233263 return project (reshape (vec (dx), 1 , :))
234264end
235265
236266# Zero-dimensional arrays -- these have a habit of going missing,
237267# although really Ref() is probably a better structure.
238- function (project:: ProjectTo{AbstractArray} )(dx:: Number ) # ... so we restore from numbers
239- if ! (project. axes isa Tuple{})
240- throw (DimensionMismatch (
241- " array with ndims(x) == $(length (project. axes)) > 0 cannot have dx::Number" ,
242- ))
243- end
244- return fill (project. element (dx))
245- end
268+ # function (project::ProjectTo{<:AbstractArray})(dx::Number) # ... so we restore from numbers
269+ # if !(project.axes isa Tuple{})
270+ # throw(DimensionMismatch(
271+ # "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
272+ # ))
273+ # end
274+ # return fill(project.element(dx))
275+ # end
276+ (project:: ProjectTo{AbstractArray{<:Number,0}} )(dx:: Number ) = fill (project. element (dx))
246277
247278function _projection_mismatch (axes_x:: Tuple , size_dx:: Tuple )
248279 size_x = map (length, axes_x)
0 commit comments