Skip to content

Commit 32cab7e

Browse files
mcabbottoxinabox
authored andcommitted
diagonal etc
1 parent db51820 commit 32cab7e

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

src/projection.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
(p::ProjectTo{T})(dx)
33
44
Projects the differential `dx` onto a specific tangent space.
5-
This guarantees `p(dx)::T`, except for allowing `dx::AbstractZero` to pass through.
5+
6+
The type `T` is meant to encode the largest acceptable space, so usually
7+
this enforces `p(dx)::T`. But some subspaces which aren't subtypes of `T` may
8+
be allowed, and in particular `dx::AbstractZero` always passes through.
69
710
Usually `T` is the "outermost" part of the type, and `p` stores additional
811
properties such as projectors for each constituent field.
@@ -245,23 +248,19 @@ end
245248

246249
# Diagonal
247250
function ProjectTo(x::Diagonal)
248-
eltype(x) == Bool && return ProjectTo(false)
249-
sub = ProjectTo(get_diag(x))
250-
sub isa ProjectTo{<:AbstractZero} && return sub
251+
sub = ProjectTo(x.diag)
252+
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if Diagonal(NoTangent()) worked
251253
return ProjectTo{Diagonal}(; diag=sub)
252254
end
253-
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(get_diag(dx)))
254-
255-
get_diag(x) = diag(x)
256-
get_diag(x::Diagonal) = x.diag
255+
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
256+
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
257257

258258
# Symmetric
259259
for (SymHerm, chk, fun) in ((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint))
260260
@eval begin
261261
function ProjectTo(x::$SymHerm)
262-
eltype(x) == Bool && return ProjectTo(false)
263262
sub = ProjectTo(parent(x))
264-
sub isa ProjectTo{<:AbstractZero} && return sub
263+
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if Hermitian(NoTangent()) etc. worked
265264
return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub)
266265
end
267266
function (project::ProjectTo{$SymHerm})(dx::AbstractArray)
@@ -285,12 +284,16 @@ end
285284
for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg
286285
@eval begin
287286
function ProjectTo(x::$UL)
288-
eltype(x) == Bool && return ProjectTo(false)
289287
sub = ProjectTo(parent(x))
290-
sub isa ProjectTo{<:AbstractZero} && return sub
288+
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if UnitUpperTriangular(NoTangent()) etc. worked
291289
return ProjectTo{$UL}(; parent=sub)
292290
end
293291
(project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx))
292+
function (project::ProjectTo{$UL})(dx::Diagonal)
293+
sub = project.parent
294+
sub_one = ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],))
295+
return Diagonal(sub_one(dx.diag))
296+
end
294297
end
295298
end
296299

test/projection.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ using OffsetArrays, BenchmarkTools
116116
@test pupp(rand(ComplexF32, 3, 3, 1)) isa UpperTriangular{Float64}
117117
@test ProjectTo(UpperTriangular(randn(3,3) .> 0))(randn(3,3)) == NoTangent()
118118

119+
# some subspaces which aren't subtypes
120+
@test psymm(Diagonal([1,2,3])) isa Diagonal{Float64}
121+
@test pupp(Diagonal([1,2,3+4im])) isa Diagonal{Float64}
122+
119123
# structured matrices with linear-size backing
120124
pdiag = ProjectTo(Diagonal(1:3))
121125
@test pdiag(reshape(1:9,3,3)) == Diagonal([1,5,9])

0 commit comments

Comments
 (0)