diff --git a/docs/src/changelog.md b/docs/src/changelog.md index 8ce4f956..ac4c8438 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec ### Fixed +- Polar decompositions return exact hermitian factors ([#143](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/143) + ## [0.6.1](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.0...v0.6.1) - 2025-12-28 ### Added diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index ff150f24..fd0c1604 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -198,4 +198,12 @@ function MatrixAlgebraKit.truncate( return Vᴴ[ind, :], ind end +# avoids calling the BlasMat specialization that assumes syrk! or herk! is called +# TODO: remove once syrk! or herk! is defined +function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix{T}) where {T <: BlasFloat} + mul!(C, A, A') + project_hermitian!(C) + return C +end + end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 4d34dd9e..e3acb553 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -183,4 +183,12 @@ function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix) return A, B end +# avoids calling the BlasMat specialization that assumes syrk! or herk! is called +# TODO: remove once syrk! or herk! is defined +function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T}) where {T <: BlasFloat} + mul!(C, A, A') + project_hermitian!(C) + return C +end + end diff --git a/src/implementations/polar.jl b/src/implementations/polar.jl index b94b2afa..00f9cbbd 100644 --- a/src/implementations/polar.jl +++ b/src/implementations/polar.jl @@ -53,7 +53,7 @@ function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD) if !isempty(P) S .= sqrt.(S) SsqrtVᴴ = lmul!(S, Vᴴ) - P = mul!(P, SsqrtVᴴ', SsqrtVᴴ) + P = _mul_herm!(P, SsqrtVᴴ') end return (W, P) end @@ -65,11 +65,24 @@ function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD) if !isempty(P) S .= sqrt.(S) USsqrt = rmul!(U, S) - P = mul!(P, USsqrt, USsqrt') + P = _mul_herm!(P, USsqrt) end return (P, Wᴴ) end +# Implement `mul!(C, A', A)` and guarantee the result is hermitian. +# For BLAS calls that dispatch to `syrk` or `herk` this works automatically +# for GPU this currently does not seem to be guaranteed so we manually project +function _mul_herm!(C, A) + mul!(C, A, A') + project_hermitian!(C) + return C +end +function _mul_herm!(C::YALAPACK.BlasMat{T}, A::YALAPACK.BlasMat{T}) where {T <: YALAPACK.BlasFloat} + mul!(C, A, A') + return C +end + # Implementation via Newton # -------------------------- function left_polar!(A::AbstractMatrix, WP, alg::PolarNewton) diff --git a/test/testsuite/polar.jl b/test/testsuite/polar.jl index c610ba34..d858c0c1 100644 --- a/test/testsuite/polar.jl +++ b/test/testsuite/polar.jl @@ -17,34 +17,30 @@ function test_left_polar( ) summary_str = testargs_summary(T, sz) return @testset "left_polar! algorithm $alg $summary_str" for alg in algs - @testset "algorithm $alg" for alg in algs - A = instantiate_matrix(T, sz) - Ac = deepcopy(A) - W, P = left_polar(A; alg) - @test eltype(W) == eltype(A) && size(W) == (size(A, 1), size(A, 2)) - @test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2)) - @test W * P ≈ A - @test isisometric(W) - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + W, P = left_polar(A; alg) + @test eltype(W) == eltype(A) && size(W) == (size(A, 1), size(A, 2)) + @test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2)) + @test W * P ≈ A + @test isisometric(W) + @test isposdef(P) - W2, P2 = @testinferred left_polar!(Ac, (W, P), alg) - @test W2 === W - @test P2 === P - @test W * P ≈ A - @test isisometric(W) - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) + W2, P2 = @testinferred left_polar!(Ac, (W, P), alg) + @test W2 === W + @test P2 === P + @test W * P ≈ A + @test isisometric(W) + @test isposdef(P) - noP = similar(P, (0, 0)) - W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg) - @test P2 === noP - @test W2 === W - @test isisometric(W) - P = W' * A # compute P explicitly to verify W correctness - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) - end + noP = similar(P, (0, 0)) + W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg) + @test P2 === noP + @test W2 === W + @test isisometric(W) + P = W' * A # compute P explicitly to verify W correctness + @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) + @test isposdef(project_hermitian!(P)) end end @@ -62,16 +58,14 @@ function test_right_polar( @test eltype(P) == eltype(A) && size(P) == (size(A, 1), size(A, 1)) @test P * Wᴴ ≈ A @test isisometric(Wᴴ; side = :right) - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) + @test isposdef(P) P2, Wᴴ2 = @testinferred right_polar!(Ac, (P, Wᴴ), alg) @test P2 === P @test Wᴴ2 === Wᴴ @test P * Wᴴ ≈ A @test isisometric(Wᴴ; side = :right) - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) + @test isposdef(P) noP = similar(P, (0, 0)) P2, Wᴴ2 = @testinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg)