287287# Since this works like a zero-array in broadcasting, it should also accept a number:
288288(project:: ProjectTo{<:Tangent{<:Ref}} )(dx:: Number ) = project (Ref (dx))
289289
290- # Tuple
290+ # Tuple and NamedTuple
291291function ProjectTo (x:: Tuple )
292292 elements = map (ProjectTo, x)
293293 if elements isa NTuple{<: Any ,ProjectTo{<: AbstractZero }}
@@ -296,10 +296,22 @@ function ProjectTo(x::Tuple)
296296 return ProjectTo {Tangent{typeof(x)}} (; elements= elements)
297297 end
298298end
299+ function ProjectTo (x:: NamedTuple )
300+ elements = map (ProjectTo, x)
301+ if Tuple (elements) isa NTuple{<: Any ,ProjectTo{<: AbstractZero }}
302+ return ProjectTo {NoTangent} ()
303+ else
304+ return ProjectTo {Tangent{typeof(x)}} (; elements... )
305+ end
306+ end
307+
299308# This method means that projection is re-applied to the contents of a Tangent.
300309# We're not entirely sure whether this is every necessary; but it should be safe,
301310# and should often compile away:
302- (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: Tangent ) = project (backing (dx))
311+ function (project:: ProjectTo{<:Tangent{<:Union{Tuple,NamedTuple}}} )(dx:: Tangent )
312+ return project (backing (dx))
313+ end
314+
303315function (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: Tuple )
304316 len = length (project. elements)
305317 if length (dx) != len
@@ -310,6 +322,45 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
310322 dy = map ((f, x) -> f (x), project. elements, dx)
311323 return project_type (project)(dy... )
312324end
325+ function (project:: ProjectTo{<:Tangent{<:NamedTuple}} )(dx:: NamedTuple )
326+ dy = _project_namedtuple (backing (project), dx)
327+ return project_type (project)(; dy... )
328+ end
329+
330+ # Diffractor returns not necessarily a named tuple with all keys and of the same order as
331+ # the projector
332+ # Thus we can't use `map`
333+ function _project_namedtuple (f:: NamedTuple{fn,ft} , x:: NamedTuple{xn,xt} ) where {fn,ft,xn,xt}
334+ if @generated
335+ vals = Any[
336+ if xn[i] in fn
337+ :(getfield (f, $ (QuoteNode (xn[i])))(getfield (x, $ (QuoteNode (xn[i])))))
338+ else
339+ throw (
340+ ArgumentError (
341+ " named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i]) " ,
342+ ),
343+ )
344+ end for i in 1 : length (xn)
345+ ]
346+ :(NamedTuple {$xn} (($ (vals... ),)))
347+ else
348+ vals = ntuple (Val (length (xn))) do i
349+ name = xn[i]
350+ if name in fn
351+ getfield (f, name)(getfield (x, name))
352+ else
353+ throw (
354+ ArgumentError (
355+ " named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i]) " ,
356+ ),
357+ )
358+ end
359+ end
360+ NamedTuple {xn} (vals)
361+ end
362+ end
363+
313364function (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: AbstractArray )
314365 for d in 1 : ndims (dx)
315366 if size (dx, d) != get (length (project. elements), d, 1 )
0 commit comments