Skip to content

Commit 4b6cb05

Browse files
authored
Guarantee hermitian results in polar decompositions (#143)
* guarantee hermitian results in polar * update tests * fix double `@testset` * change to `syrk` convention * update changelog * some reshuffling to actually correctly dispatch
1 parent 3f1c86a commit 4b6cb05

File tree

5 files changed

+57
-32
lines changed

5 files changed

+57
-32
lines changed

docs/src/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec
3030

3131
### Fixed
3232

33+
- Polar decompositions return exact hermitian factors ([#143](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/143)
34+
3335
## [0.6.1](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.0...v0.6.1) - 2025-12-28
3436

3537
### Added

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,4 +198,12 @@ function MatrixAlgebraKit.truncate(
198198
return Vᴴ[ind, :], ind
199199
end
200200

201+
# avoids calling the BlasMat specialization that assumes syrk! or herk! is called
202+
# TODO: remove once syrk! or herk! is defined
203+
function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix{T}) where {T <: BlasFloat}
204+
mul!(C, A, A')
205+
project_hermitian!(C)
206+
return C
207+
end
208+
201209
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,4 +183,12 @@ function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
183183
return A, B
184184
end
185185

186+
# avoids calling the BlasMat specialization that assumes syrk! or herk! is called
187+
# TODO: remove once syrk! or herk! is defined
188+
function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T}) where {T <: BlasFloat}
189+
mul!(C, A, A')
190+
project_hermitian!(C)
191+
return C
192+
end
193+
186194
end

src/implementations/polar.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD)
5353
if !isempty(P)
5454
S .= sqrt.(S)
5555
SsqrtVᴴ = lmul!(S, Vᴴ)
56-
P = mul!(P, SsqrtVᴴ', SsqrtVᴴ)
56+
P = _mul_herm!(P, SsqrtVᴴ')
5757
end
5858
return (W, P)
5959
end
@@ -65,11 +65,24 @@ function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD)
6565
if !isempty(P)
6666
S .= sqrt.(S)
6767
USsqrt = rmul!(U, S)
68-
P = mul!(P, USsqrt, USsqrt')
68+
P = _mul_herm!(P, USsqrt)
6969
end
7070
return (P, Wᴴ)
7171
end
7272

73+
# Implement `mul!(C, A', A)` and guarantee the result is hermitian.
74+
# For BLAS calls that dispatch to `syrk` or `herk` this works automatically
75+
# for GPU this currently does not seem to be guaranteed so we manually project
76+
function _mul_herm!(C, A)
77+
mul!(C, A, A')
78+
project_hermitian!(C)
79+
return C
80+
end
81+
function _mul_herm!(C::YALAPACK.BlasMat{T}, A::YALAPACK.BlasMat{T}) where {T <: YALAPACK.BlasFloat}
82+
mul!(C, A, A')
83+
return C
84+
end
85+
7386
# Implementation via Newton
7487
# --------------------------
7588
function left_polar!(A::AbstractMatrix, WP, alg::PolarNewton)

test/testsuite/polar.jl

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,30 @@ function test_left_polar(
1717
)
1818
summary_str = testargs_summary(T, sz)
1919
return @testset "left_polar! algorithm $alg $summary_str" for alg in algs
20-
@testset "algorithm $alg" for alg in algs
21-
A = instantiate_matrix(T, sz)
22-
Ac = deepcopy(A)
23-
W, P = left_polar(A; alg)
24-
@test eltype(W) == eltype(A) && size(W) == (size(A, 1), size(A, 2))
25-
@test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2))
26-
@test W * P A
27-
@test isisometric(W)
28-
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
29-
@test isposdef(project_hermitian!(P))
20+
A = instantiate_matrix(T, sz)
21+
Ac = deepcopy(A)
22+
W, P = left_polar(A; alg)
23+
@test eltype(W) == eltype(A) && size(W) == (size(A, 1), size(A, 2))
24+
@test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2))
25+
@test W * P A
26+
@test isisometric(W)
27+
@test isposdef(P)
3028

31-
W2, P2 = @testinferred left_polar!(Ac, (W, P), alg)
32-
@test W2 === W
33-
@test P2 === P
34-
@test W * P A
35-
@test isisometric(W)
36-
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
37-
@test isposdef(project_hermitian!(P))
29+
W2, P2 = @testinferred left_polar!(Ac, (W, P), alg)
30+
@test W2 === W
31+
@test P2 === P
32+
@test W * P A
33+
@test isisometric(W)
34+
@test isposdef(P)
3835

39-
noP = similar(P, (0, 0))
40-
W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg)
41-
@test P2 === noP
42-
@test W2 === W
43-
@test isisometric(W)
44-
P = W' * A # compute P explicitly to verify W correctness
45-
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
46-
@test isposdef(project_hermitian!(P))
47-
end
36+
noP = similar(P, (0, 0))
37+
W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg)
38+
@test P2 === noP
39+
@test W2 === W
40+
@test isisometric(W)
41+
P = W' * A # compute P explicitly to verify W correctness
42+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
43+
@test isposdef(project_hermitian!(P))
4844
end
4945
end
5046

@@ -62,16 +58,14 @@ function test_right_polar(
6258
@test eltype(P) == eltype(A) && size(P) == (size(A, 1), size(A, 1))
6359
@test P * Wᴴ A
6460
@test isisometric(Wᴴ; side = :right)
65-
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
66-
@test isposdef(project_hermitian!(P))
61+
@test isposdef(P)
6762

6863
P2, Wᴴ2 = @testinferred right_polar!(Ac, (P, Wᴴ), alg)
6964
@test P2 === P
7065
@test Wᴴ2 === Wᴴ
7166
@test P * Wᴴ A
7267
@test isisometric(Wᴴ; side = :right)
73-
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
74-
@test isposdef(project_hermitian!(P))
68+
@test isposdef(P)
7569

7670
noP = similar(P, (0, 0))
7771
P2, Wᴴ2 = @testinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg)

0 commit comments

Comments
 (0)