Skip to content

Commit aaa748a

Browse files
authored
Merge branch 'main' into copilot/update-docstrings-eigenvalue-decompositions
2 parents cef025d + d082c7d commit aaa748a

File tree

8 files changed

+286
-48
lines changed

8 files changed

+286
-48
lines changed

src/MatrixAlgebraKit.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@ module MatrixAlgebraKit
33
using LinearAlgebra: LinearAlgebra
44
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
55
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
6-
using LinearAlgebra: sylvester
6+
using LinearAlgebra: sylvester, lu!
77
using LinearAlgebra: isposdef, issymmetric
88
using LinearAlgebra: Diagonal, diag, diagind, isdiag
99
using LinearAlgebra: UpperTriangular, LowerTriangular
1010
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt
1111

1212
export isisometry, isunitary, ishermitian, isantihermitian
1313

14-
export project_hermitian, project_antihermitian
15-
export project_hermitian!, project_antihermitian!
14+
export project_hermitian, project_antihermitian, project_isometric
15+
export project_hermitian!, project_antihermitian!, project_isometric!
1616
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
1717
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
1818
export svd_compact, svd_full, svd_vals, svd_trunc
@@ -34,6 +34,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3434
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3535
LAPACK_DivideAndConquer, LAPACK_Jacobi
3636
export LQViaTransposedQR
37+
export PolarViaSVD, PolarNewton
3738
export DiagonalAlgorithm
3839
export NativeBlocked
3940
export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar,

src/implementations/polar.jl

Lines changed: 141 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function check_input(::typeof(left_polar!), A::AbstractMatrix, WP, ::AbstractAlg
1111
@assert W isa AbstractMatrix && P isa AbstractMatrix
1212
@check_size(W, (m, n))
1313
@check_scalar(W, A)
14-
@check_size(P, (n, n))
14+
isempty(P) || @check_size(P, (n, n))
1515
@check_scalar(P, A)
1616
return nothing
1717
end
@@ -21,7 +21,7 @@ function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, ::Abstrac
2121
n >= m ||
2222
throw(ArgumentError("input matrix needs at least as many columns as rows"))
2323
@assert P isa AbstractMatrix && Wᴴ isa AbstractMatrix
24-
@check_size(P, (m, m))
24+
isempty(P) || @check_size(P, (m, m))
2525
@check_scalar(P, A)
2626
@check_size(Wᴴ, (m, n))
2727
@check_scalar(Wᴴ, A)
@@ -43,25 +43,154 @@ function initialize_output(::typeof(right_polar!), A::AbstractMatrix, ::Abstract
4343
return (P, Wᴴ)
4444
end
4545

46-
# Implementation
47-
# --------------
46+
# Implementation via SVD
47+
# -----------------------
4848
function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD)
4949
check_input(left_polar!, A, WP, alg)
50-
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
50+
U, S, Vᴴ = svd_compact!(A, alg.svd_alg)
5151
W, P = WP
5252
W = mul!(W, U, Vᴴ)
53-
S .= sqrt.(S)
54-
SsqrtVᴴ = lmul!(S, Vᴴ)
55-
P = mul!(P, SsqrtVᴴ', SsqrtVᴴ)
53+
if !isempty(P)
54+
S .= sqrt.(S)
55+
SsqrtVᴴ = lmul!(S, Vᴴ)
56+
P = mul!(P, SsqrtVᴴ', SsqrtVᴴ)
57+
end
5658
return (W, P)
5759
end
5860
function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD)
5961
check_input(right_polar!, A, PWᴴ, alg)
60-
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
62+
U, S, Vᴴ = svd_compact!(A, alg.svd_alg)
6163
P, Wᴴ = PWᴴ
6264
Wᴴ = mul!(Wᴴ, U, Vᴴ)
63-
S .= sqrt.(S)
64-
USsqrt = rmul!(U, S)
65-
P = mul!(P, USsqrt, USsqrt')
65+
if !isempty(P)
66+
S .= sqrt.(S)
67+
USsqrt = rmul!(U, S)
68+
P = mul!(P, USsqrt, USsqrt')
69+
end
6670
return (P, Wᴴ)
6771
end
72+
73+
# Implementation via Newton
74+
# --------------------------
75+
function left_polar!(A::AbstractMatrix, WP, alg::PolarNewton)
76+
check_input(left_polar!, A, WP, alg)
77+
W, P = WP
78+
if isempty(P)
79+
W = _left_polarnewton!(A, W, P; alg.kwargs...)
80+
return W, P
81+
else
82+
W = _left_polarnewton!(copy(A), W, P; alg.kwargs...)
83+
# we still need `A` to compute `P`
84+
P = project_hermitian!(mul!(P, W', A))
85+
return W, P
86+
end
87+
end
88+
89+
function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarNewton)
90+
check_input(right_polar!, A, PWᴴ, alg)
91+
P, Wᴴ = PWᴴ
92+
if isempty(P)
93+
Wᴴ = _right_polarnewton!(A, Wᴴ, P; alg.kwargs...)
94+
return P, Wᴴ
95+
else
96+
Wᴴ = _right_polarnewton!(copy(A), Wᴴ, P; alg.kwargs...)
97+
# we still need `A` to compute `P`
98+
P = project_hermitian!(mul!(P, A, Wᴴ'))
99+
return P, Wᴴ
100+
end
101+
end
102+
103+
# these methods only compute W and destroy A in the process
104+
function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
105+
m, n = size(A) # we must have m >= n
106+
Rᴴinv = isempty(P) ? similar(P, (n, n)) : P # use P as workspace when available
107+
if m > n # initial QR
108+
Q, R = qr_compact!(A)
109+
Rc = view(A, 1:n, 1:n)
110+
copy!(Rc, R)
111+
Rᴴinv = ldiv!(UpperTriangular(Rc)', one!(Rᴴinv))
112+
else # m == n
113+
R = A
114+
Rc = view(W, 1:n, 1:n)
115+
copy!(Rc, R)
116+
Rᴴinv = ldiv!(lu!(Rc)', one!(Rᴴinv))
117+
end
118+
γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
119+
rmul!(R, γ)
120+
rmul!(Rᴴinv, 1 / γ)
121+
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
122+
copy!(Rc, R)
123+
i = 1
124+
conv = norm(Rᴴinv, Inf)
125+
while i < maxiter && conv > tol
126+
Rᴴinv = ldiv!(lu!(Rc)', one!(Rᴴinv))
127+
γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
128+
rmul!(R, γ)
129+
rmul!(Rᴴinv, 1 / γ)
130+
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
131+
copy!(Rc, R)
132+
conv = norm(Rᴴinv, Inf)
133+
i += 1
134+
end
135+
if conv > tol
136+
@warn "`left_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv)"
137+
end
138+
if m > n
139+
return mul!(W, Q, Rc)
140+
end
141+
return W
142+
end
143+
144+
function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
145+
m, n = size(A) # we must have m <= n
146+
Lᴴinv = isempty(P) ? similar(P, (m, m)) : P # use P as workspace when available
147+
if m < n # initial QR
148+
L, Q = lq_compact!(A)
149+
Lc = view(A, 1:m, 1:m)
150+
copy!(Lc, L)
151+
Lᴴinv = ldiv!(LowerTriangular(Lc)', one!(Lᴴinv))
152+
else # m == n
153+
L = A
154+
Lc = view(Wᴴ, 1:m, 1:m)
155+
copy!(Lc, L)
156+
Lᴴinv = ldiv!(lu!(Lc)', one!(Lᴴinv))
157+
end
158+
γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
159+
rmul!(L, γ)
160+
rmul!(Lᴴinv, 1 / γ)
161+
L, Lᴴinv = _avgdiff!(L, Lᴴinv)
162+
copy!(Lc, L)
163+
i = 1
164+
conv = norm(Lᴴinv, Inf)
165+
while i < maxiter && conv > tol
166+
Lᴴinv = ldiv!(lu!(Lc)', one!(Lᴴinv))
167+
γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
168+
rmul!(L, γ)
169+
rmul!(Lᴴinv, 1 / γ)
170+
L, Lᴴinv = _avgdiff!(L, Lᴴinv)
171+
copy!(Lc, L)
172+
conv = norm(Lᴴinv, Inf)
173+
i += 1
174+
end
175+
if conv > tol
176+
@warn "`right_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv)"
177+
end
178+
if m < n
179+
return mul!(Wᴴ, Lc, Q)
180+
end
181+
return Wᴴ
182+
end
183+
184+
# in place computation of the average and difference of two arrays
185+
function _avgdiff!(A::AbstractArray, B::AbstractArray)
186+
axes(A) == axes(B) || throw(DimensionMismatch())
187+
@simd for I in eachindex(A, B)
188+
@inbounds begin
189+
a = A[I]
190+
b = B[I]
191+
A[I] = (a + b) / 2
192+
B[I] = b - a
193+
end
194+
end
195+
return A, B
196+
end

src/implementations/projections.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ function copy_input(::typeof(project_hermitian), A::AbstractMatrix)
55
end
66
copy_input(::typeof(project_antihermitian), A) = copy_input(project_hermitian, A)
77

8+
copy_input(::typeof(project_isometric), A) = copy_input(left_polar, A)
9+
810
function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
911
LinearAlgebra.checksquare(A)
1012
n = Base.require_one_based_indexing(A)
@@ -18,6 +20,16 @@ function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::Abs
1820
return nothing
1921
end
2022

23+
function check_input(::typeof(project_isometric!), A::AbstractMatrix, W::AbstractMatrix, ::AbstractAlgorithm)
24+
m, n = size(A)
25+
m >= n ||
26+
throw(ArgumentError("input matrix needs at least as many rows as columns"))
27+
@assert W isa AbstractMatrix
28+
@check_size(W, (m, n))
29+
@check_scalar(W, A)
30+
return nothing
31+
end
32+
2133
# Outputs
2234
# -------
2335
function initialize_output(::typeof(project_hermitian!), A::AbstractMatrix, ::NativeBlocked)
@@ -27,15 +39,26 @@ function initialize_output(::typeof(project_antihermitian!), A::AbstractMatrix,
2739
return A
2840
end
2941

42+
function initialize_output(::typeof(project_isometric!), A::AbstractMatrix, ::AbstractAlgorithm)
43+
return similar(A)
44+
end
45+
3046
# Implementation
3147
# --------------
32-
function project_hermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
33-
check_input(project_hermitian!, A, B, alg)
34-
return project_hermitian_native!(A, B, Val(false); alg.kwargs...)
48+
function project_hermitian!(A::AbstractMatrix, Aₕ, alg::NativeBlocked)
49+
check_input(project_hermitian!, A, Aₕ, alg)
50+
return project_hermitian_native!(A, Aₕ, Val(false); alg.kwargs...)
3551
end
36-
function project_antihermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
37-
check_input(project_antihermitian!, A, B, alg)
38-
return project_hermitian_native!(A, B, Val(true); alg.kwargs...)
52+
function project_antihermitian!(A::AbstractMatrix, Aₐ, alg::NativeBlocked)
53+
check_input(project_antihermitian!, A, Aₐ, alg)
54+
return project_hermitian_native!(A, Aₐ, Val(true); alg.kwargs...)
55+
end
56+
57+
function project_isometric!(A::AbstractMatrix, W, alg::AbstractAlgorithm)
58+
check_input(project_isometric!, A, W, alg)
59+
noP = similar(W, (0, 0))
60+
W, _ = left_polar!(A, (W, noP), alg)
61+
return W
3962
end
4063

4164
function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)

src/interface/decompositions.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,28 @@ const LAPACK_SVDAlgorithm = Union{
117117
LAPACK_Jacobi,
118118
}
119119

120+
# =========================
121+
# Polar decompositions
122+
# =========================
123+
"""
124+
PolarViaSVD(svd_alg)
125+
126+
Algorithm for computing the polar decomposition of a matrix `A` via the singular value
127+
decomposition (SVD) of `A`. The `svd_alg` argument specifies the SVD algorithm to use.
128+
"""
129+
struct PolarViaSVD{SVDAlg} <: AbstractAlgorithm
130+
svd_alg::SVDAlg
131+
end
132+
133+
"""
134+
PolarNewton(; maxiter = 10, tol = defaulttol(A))
135+
136+
Algorithm for computing the polar decomposition of a matrix `A` via
137+
scaled Newton iteration, with a maximum of `maxiter` iterations and
138+
until convergence up to tolerance `tol`.
139+
"""
140+
@algdef PolarNewton
141+
120142
# =========================
121143
# DIAGONAL ALGORITHMS
122144
# =========================

src/interface/polar.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,6 @@ See also [`left_polar(!)`](@ref left_polar).
3939
"""
4040
@functiondef right_polar
4141

42-
"""
43-
PolarViaSVD(svdalg)
44-
45-
Algorithm for computing the polar decomposition of a matrix `A` via the singular value
46-
decomposition (SVD) of `A`. The `svdalg` argument specifies the SVD algorithm to use.
47-
"""
48-
struct PolarViaSVD{SVDAlg} <: AbstractAlgorithm
49-
svdalg::SVDAlg
50-
end
51-
5242
# Algorithm selection
5343
# -------------------
5444
default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...)

src/interface/projections.jl

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
@doc """
22
project_hermitian(A; kwargs...)
33
project_hermitian(A, alg)
4-
project_hermitian!(A; kwargs...)
5-
project_hermitian!(A, alg)
4+
project_hermitian!(A, [Aₕ]; kwargs...)
5+
project_hermitian!(A, [Aₕ], alg)
66
77
Compute the hermitian part of a (square) matrix `A`, defined as `(A + A') / 2`.
8-
For real matrices this corresponds to the symmetric part of `A`.
8+
For real matrices this corresponds to the symmetric part of `A`. In the bang method,
9+
the output storage can be provided via the optional argument `Aₕ`; by default it is
10+
equal to `A` and so the input matrix `A` is replaced by its hermitian projection.
911
1012
See also [`project_antihermitian`](@ref).
1113
"""
@@ -14,16 +16,36 @@ See also [`project_antihermitian`](@ref).
1416
@doc """
1517
project_antihermitian(A; kwargs...)
1618
project_antihermitian(A, alg)
17-
project_antihermitian!(A; kwargs...)
18-
project_antihermitian!(A, alg)
19+
project_antihermitian!(A, [Aₐ]; kwargs...)
20+
project_antihermitian!(A, [Aₐ], alg)
1921
2022
Compute the anti-hermitian part of a (square) matrix `A`, defined as `(A - A') / 2`.
21-
For real matrices this corresponds to the antisymmetric part of `A`.
23+
For real matrices this corresponds to the antisymmetric part of `A`. In the bang method,
24+
the output storage can be provided via the optional argument `Aₐ``; by default it is
25+
equal to `A` and so the input matrix `A` is replaced by its antihermitian projection.
2226
2327
See also [`project_hermitian`](@ref).
2428
"""
2529
@functiondef project_antihermitian
2630

31+
@doc """
32+
project_isometric(A; kwargs...)
33+
project_isometric(A, alg)
34+
project_isometric!(A, [W]; kwargs...)
35+
project_isometric!(A, [W], alg)
36+
37+
Compute the projection of `A` onto the manifold of isometric matrices, i.e. matrices
38+
satisfying `A' * A ≈ I`. This projection is computed via the polar decomposition, i.e.
39+
`W` corresponds to the first return value of `left_polar!`, but avoids computing the
40+
positive definite factor explicitly.
41+
42+
!!! note
43+
The bang method `project_isometric!` optionally accepts the output structure and
44+
possibly destroys the input matrix `A`. Always use the return value of the function
45+
as it may not always be possible to use the provided `W` as output.
46+
"""
47+
@functiondef project_isometric
48+
2749
"""
2850
NativeBlocked(; blocksize = 32)
2951
@@ -43,3 +65,6 @@ for f in (:project_hermitian!, :project_antihermitian!)
4365
return default_hermitian_algorithm(A; kwargs...)
4466
end
4567
end
68+
69+
default_algorithm(::typeof(project_isometric!), ::Type{A}; kwargs...) where {A} =
70+
default_polar_algorithm(A; kwargs...)

0 commit comments

Comments
 (0)