|
| 1 | +""" |
| 2 | + ProjectTo(x::T) |
| 3 | +
|
| 4 | +Returns a `ProjectTo{T,...}` functor able to project a differential `dx` onto the same tangent space as `x`. |
| 5 | +This functor encloses over what ever is needed to be able to be able to do that projection. |
| 6 | +For example, when projecting `dx=ZeroTangent()` on an array `P=Array{T, N}`, the size of `x` |
| 7 | +is not available from `P`, so it is stored in the functor. |
| 8 | +
|
| 9 | + (::ProjectTo{T})(dx) |
| 10 | +
|
| 11 | +Projects the differential `dx` on the onto the tangent space used to create the `ProjectTo`. |
| 12 | +""" |
| 13 | +struct ProjectTo{P,D<:NamedTuple} |
| 14 | + info::D |
| 15 | +end |
| 16 | +ProjectTo{P}(info::D) where {P,D<:NamedTuple} = ProjectTo{P,D}(info) |
| 17 | +ProjectTo{P}(; kwargs...) where {P} = ProjectTo{P}(NamedTuple(kwargs)) |
| 18 | + |
| 19 | +backing(project::ProjectTo) = getfield(project, :info) |
| 20 | +Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name) |
| 21 | +Base.propertynames(p::ProjectTo) = propertynames(backing(p)) |
| 22 | + |
| 23 | +function Base.show(io::IO, project::ProjectTo{T}) where {T} |
| 24 | + print(io, "ProjectTo{") |
| 25 | + show(io, T) |
| 26 | + print(io, "}") |
| 27 | + if isempty(backing(project)) |
| 28 | + print(io, "()") |
| 29 | + else |
| 30 | + show(io, backing(project)) |
| 31 | + end |
| 32 | +end |
| 33 | + |
| 34 | +# fallback (structs) |
| 35 | +function ProjectTo(x::T) where {T} |
| 36 | + # Generic fallback for structs, recursively make `ProjectTo`s all their fields |
| 37 | + fields_nt::NamedTuple = backing(x) |
| 38 | + return ProjectTo{T}(map(ProjectTo, fields_nt)) |
| 39 | +end |
| 40 | +function (project::ProjectTo{T})(dx::Tangent) where {T} |
| 41 | + sub_projects = backing(project) |
| 42 | + sub_dxs = backing(canonicalize(dx)) |
| 43 | + _call(f, x) = f(x) |
| 44 | + return construct(T, map(_call, sub_projects, sub_dxs)) |
| 45 | +end |
| 46 | + |
| 47 | +# should not work for Tuples and NamedTuples, as not valid tangent types |
| 48 | +function ProjectTo(x::T) where {T<:Union{<:Tuple,NamedTuple}} |
| 49 | + return throw( |
| 50 | + ArgumentError("The `x` in `ProjectTo(x)` must be a valid differential, not $x") |
| 51 | + ) |
| 52 | +end |
| 53 | + |
| 54 | +# Generic |
| 55 | +(project::ProjectTo)(dx::AbstractThunk) = project(unthunk(dx)) |
| 56 | +(::ProjectTo{T})(dx::T) where {T} = dx # not always true, but we can special case for when it isn't |
| 57 | +(::ProjectTo{T})(dx::AbstractZero) where {T} = zero(T) |
| 58 | + |
| 59 | +# Number |
| 60 | +ProjectTo(::T) where {T<:Number} = ProjectTo{T}() |
| 61 | +(::ProjectTo{T})(dx::Number) where {T<:Number} = convert(T, dx) |
| 62 | +(::ProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx)) |
| 63 | + |
| 64 | +# Arrays |
| 65 | +ProjectTo(xs::T) where {T<:Array} = ProjectTo{T}(; elements=map(ProjectTo, xs)) |
| 66 | +function (project::ProjectTo{T})(dx::Array) where {T<:Array} |
| 67 | + _call(f, x) = f(x) |
| 68 | + return T(map(_call, project.elements, dx)) |
| 69 | +end |
| 70 | +function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Array} |
| 71 | + return T(map(proj -> proj(dx), project.elements)) |
| 72 | +end |
| 73 | +(project::ProjectTo{<:Array})(dx::AbstractArray) = project(collect(dx)) |
| 74 | + |
| 75 | +# Arrays{<:Number}: optimized case so we don't need a projector per element |
| 76 | +function ProjectTo(x::T) where {E<:Number,T<:Array{E}} |
| 77 | + return ProjectTo{T}(; element=ProjectTo(zero(E)), size=size(x)) |
| 78 | +end |
| 79 | +(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = project.element.(dx) |
| 80 | +function (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} |
| 81 | + return zeros(T, project.size) |
| 82 | +end |
| 83 | +function (project::ProjectTo{<:Array{T}})(dx::Tangent{<:SubArray}) where {T<:Number} |
| 84 | + return project(dx.parent) |
| 85 | +end |
| 86 | + |
| 87 | +# Diagonal |
| 88 | +ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) |
| 89 | +(project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.diag(diag(dx))) |
| 90 | +(project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx)) |
| 91 | + |
| 92 | +# :data, :uplo fields |
| 93 | +for SymHerm in (:Symmetric, :Hermitian) |
| 94 | + @eval begin |
| 95 | + function ProjectTo(x::T) where {T<:$SymHerm} |
| 96 | + return ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) |
| 97 | + end |
| 98 | + function (project::ProjectTo{<:$SymHerm})(dx::AbstractMatrix) |
| 99 | + return $SymHerm(project.parent(dx), project.uplo) |
| 100 | + end |
| 101 | + function (project::ProjectTo{<:$SymHerm})(dx::AbstractZero) |
| 102 | + return $SymHerm(project.parent(dx), project.uplo) |
| 103 | + end |
| 104 | + function (project::ProjectTo{<:$SymHerm})(dx::Tangent) |
| 105 | + return $SymHerm(project.parent(dx.data), project.uplo) |
| 106 | + end |
| 107 | + end |
| 108 | +end |
| 109 | + |
| 110 | +# :data field |
| 111 | +for UL in (:UpperTriangular, :LowerTriangular) |
| 112 | + @eval begin |
| 113 | + ProjectTo(x::T) where {T<:$UL} = ProjectTo{T}(; parent=ProjectTo(parent(x))) |
| 114 | + (project::ProjectTo{<:$UL})(dx::AbstractMatrix) = $UL(project.parent(dx)) |
| 115 | + (project::ProjectTo{<:$UL})(dx::AbstractZero) = $UL(project.parent(dx)) |
| 116 | + (project::ProjectTo{<:$UL})(dx::Tangent) = $UL(project.parent(dx.data)) |
| 117 | + end |
| 118 | +end |
| 119 | + |
| 120 | +# Transpose |
| 121 | +ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x))) |
| 122 | +function (project::ProjectTo{<:Transpose})(dx::AbstractMatrix) |
| 123 | + return transpose(project.parent(transpose(dx))) |
| 124 | +end |
| 125 | +(project::ProjectTo{<:Transpose})(dx::AbstractZero) = transpose(project.parent(dx)) |
| 126 | + |
| 127 | +# Adjoint |
| 128 | +ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; parent=ProjectTo(parent(x))) |
| 129 | +(project::ProjectTo{<:Adjoint})(dx::AbstractMatrix) = adjoint(project.parent(adjoint(dx))) |
| 130 | +(project::ProjectTo{<:Adjoint})(dx::AbstractZero) = adjoint(project.parent(dx)) |
| 131 | + |
| 132 | +# PermutedDimsArray |
| 133 | +ProjectTo(x::P) where {P<:PermutedDimsArray} = ProjectTo{P}(; parent=ProjectTo(parent(x))) |
| 134 | +function (project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})( |
| 135 | + dx::AbstractArray |
| 136 | +) where {T,N,perm,iperm,AA} |
| 137 | + return PermutedDimsArray{T,N,perm,iperm,AA}(permutedims(project.parent(dx), perm)) |
| 138 | +end |
| 139 | +function (project::ProjectTo{P})(dx::AbstractZero) where {P<:PermutedDimsArray} |
| 140 | + return P(project.parent(dx)) |
| 141 | +end |
| 142 | + |
| 143 | +# SubArray |
| 144 | +ProjectTo(x::T) where {T<:SubArray} = ProjectTo(copy(x)) # don't project on to a view, but onto matching copy |
0 commit comments