Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec

### Fixed

- Polar decompositions return exact hermitian factors ([#143](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/143)

## [0.6.1](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.0...v0.6.1) - 2025-12-28

### Added
Expand Down
8 changes: 8 additions & 0 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,12 @@ function MatrixAlgebraKit.truncate(
return Vᴴ[ind, :], ind
end

# avoids calling the BlasMat specialization that assumes syrk! or herk! is called
# TODO: remove once syrk! or herk! is defined
function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix{T}) where {T <: BlasFloat}
mul!(C, A, A')
project_hermitian!(C)
return C
end

end
8 changes: 8 additions & 0 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,12 @@ function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
return A, B
end

# avoids calling the BlasMat specialization that assumes syrk! or herk! is called
# TODO: remove once syrk! or herk! is defined
function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T}) where {T <: BlasFloat}
mul!(C, A, A')
project_hermitian!(C)
return C
end

end
17 changes: 15 additions & 2 deletions src/implementations/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD)
if !isempty(P)
S .= sqrt.(S)
SsqrtVᴴ = lmul!(S, Vᴴ)
P = mul!(P, SsqrtVᴴ', SsqrtVᴴ)
P = _mul_herm!(P, SsqrtVᴴ')
end
return (W, P)
end
Expand All @@ -65,11 +65,24 @@ function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD)
if !isempty(P)
S .= sqrt.(S)
USsqrt = rmul!(U, S)
P = mul!(P, USsqrt, USsqrt')
P = _mul_herm!(P, USsqrt)
end
return (P, Wᴴ)
end

# Implement `mul!(C, A', A)` and guarantee the result is hermitian.
# For BLAS calls that dispatch to `syrk` or `herk` this works automatically
# for GPU this currently does not seem to be guaranteed so we manually project
function _mul_herm!(C, A)
mul!(C, A, A')
project_hermitian!(C)
return C
end
function _mul_herm!(C::YALAPACK.BlasMat{T}, A::YALAPACK.BlasMat{T}) where {T <: YALAPACK.BlasFloat}
mul!(C, A, A')
return C
end

# Implementation via Newton
# --------------------------
function left_polar!(A::AbstractMatrix, WP, alg::PolarNewton)
Expand Down
54 changes: 24 additions & 30 deletions test/testsuite/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,30 @@ function test_left_polar(
)
summary_str = testargs_summary(T, sz)
return @testset "left_polar! algorithm $alg $summary_str" for alg in algs
@testset "algorithm $alg" for alg in algs
A = instantiate_matrix(T, sz)
Ac = deepcopy(A)
W, P = left_polar(A; alg)
@test eltype(W) == eltype(A) && size(W) == (size(A, 1), size(A, 2))
@test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2))
@test W * P ≈ A
@test isisometric(W)
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
@test isposdef(project_hermitian!(P))
A = instantiate_matrix(T, sz)
Ac = deepcopy(A)
W, P = left_polar(A; alg)
@test eltype(W) == eltype(A) && size(W) == (size(A, 1), size(A, 2))
@test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2))
@test W * P ≈ A
@test isisometric(W)
@test isposdef(P)

W2, P2 = @testinferred left_polar!(Ac, (W, P), alg)
@test W2 === W
@test P2 === P
@test W * P ≈ A
@test isisometric(W)
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
@test isposdef(project_hermitian!(P))
W2, P2 = @testinferred left_polar!(Ac, (W, P), alg)
@test W2 === W
@test P2 === P
@test W * P ≈ A
@test isisometric(W)
@test isposdef(P)

noP = similar(P, (0, 0))
W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg)
@test P2 === noP
@test W2 === W
@test isisometric(W)
P = W' * A # compute P explicitly to verify W correctness
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
@test isposdef(project_hermitian!(P))
end
noP = similar(P, (0, 0))
W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg)
@test P2 === noP
@test W2 === W
@test isisometric(W)
P = W' * A # compute P explicitly to verify W correctness
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
@test isposdef(project_hermitian!(P))
end
end

Expand All @@ -62,16 +58,14 @@ function test_right_polar(
@test eltype(P) == eltype(A) && size(P) == (size(A, 1), size(A, 1))
@test P * Wᴴ ≈ A
@test isisometric(Wᴴ; side = :right)
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
@test isposdef(project_hermitian!(P))
@test isposdef(P)

P2, Wᴴ2 = @testinferred right_polar!(Ac, (P, Wᴴ), alg)
@test P2 === P
@test Wᴴ2 === Wᴴ
@test P * Wᴴ ≈ A
@test isisometric(Wᴴ; side = :right)
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
@test isposdef(project_hermitian!(P))
@test isposdef(P)

noP = similar(P, (0, 0))
P2, Wᴴ2 = @testinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg)
Expand Down