From 784c469155057a7a3a90c56f9fbb326750a83d56 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 26 May 2025 10:06:19 -0400 Subject: [PATCH 1/8] [WIP] Implement left_orth/right_orth --- Project.toml | 2 +- src/factorizations/svd.jl | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fbb1ce2f..7f3b7dfc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.6.2" +version = "0.6.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index fbaae498..ec95d28f 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -265,3 +265,7 @@ function MatrixAlgebraKit.svd_full!( return USVᴴ end + +function MatrixAlgebraKit.left_orth_svd!(A::AbstractBlockSparseMatrix, VC, alg, trunc) + return @invoke MatrixAlgebraKit.left_orth_svd!(A::Any, VC::Any, alg::Any, trunc::Any) +end From 1b918ebc7e0108ffac5a4d273018456c8a460edf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 27 May 2025 11:33:25 -0400 Subject: [PATCH 2/8] Better implementation of left_orth/right_orth --- src/BlockSparseArrays.jl | 1 + src/factorizations/orthnull.jl | 84 ++++++++++++++++++++++++++++++++++ src/factorizations/qr.jl | 19 +++++++- src/factorizations/svd.jl | 4 -- 4 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 src/factorizations/orthnull.jl diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index b7658a5a..9cd0aa25 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -47,5 +47,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl") include("factorizations/svd.jl") include("factorizations/truncation.jl") include("factorizations/qr.jl") +include("factorizations/orthnull.jl") end diff --git a/src/factorizations/orthnull.jl b/src/factorizations/orthnull.jl new file mode 100644 index 00000000..76389dd9 --- /dev/null +++ b/src/factorizations/orthnull.jl @@ -0,0 +1,84 @@ +using MatrixAlgebraKit: + MatrixAlgebraKit, left_polar, qr_compact, select_algorithm, svd_compact + +function MatrixAlgebraKit.left_orth( + A::AbstractBlockSparseMatrix; + trunc=nothing, + kind=isnothing(trunc) ? :qr : :svd, + alg_qr=(; positive=true), + alg_polar=(;), + alg_svd=(;), +) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for `left_orth` with `kind=$kind`")) + end + if kind == :qr + return left_orth_qr(A, alg_qr) + elseif kind == :polar + # TODO: Implement this. + # return left_orth_polar(A, alg_polar) + return left_orth_svd(A, alg_svd) + elseif kind == :svd + return left_orth_svd(A, alg_svd, trunc) + else + throw(ArgumentError("`left_orth` received unknown value `kind = $kind`")) + end +end +function left_orth_qr(A, alg) + alg′ = select_algorithm(qr_compact, A, alg) + return qr_compact(A, alg′) +end +function left_orth_polar(A, alg) + alg′ = select_algorithm(left_polar, A, alg) + return left_polar(A, alg′) +end +function left_orth_svd(A, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact, A, alg) + U, S, Vᴴ = svd_compact(A, alg′) + return U, S * Vᴴ +end + +function MatrixAlgebraKit.right_orth( + A; + trunc=nothing, + kind=isnothing(trunc) ? :lq : :svd, + alg_lq=(; positive=true), + alg_polar=(;), + alg_svd=(;), +) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for `right_orth` with `kind=$kind`")) + end + if kind == :qr + # TODO: Implement this. + # return right_orth_lq(A, alg_lq) + return right_orth_svd(A, alg_svd) + elseif kind == :polar + # TODO: Implement this. + # return right_orth_polar(A, alg_polar) + return right_orth_svd(A, alg_svd) + elseif kind == :svd + return right_orth_svd(A, alg_svd, trunc) + else + throw(ArgumentError("`right_orth` received unknown value `kind = $kind`")) + end +end +function right_orth_lq(A, alg) + alg′ = select_algorithm(lq_compact, A, alg) + return lq_compact(A, alg′) +end +function right_orth_polar(A, alg) + alg′ = select_algorithm(right_polar, A, alg) + return right_polar(A, alg′) +end +function right_orth_svd(A, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact, A, alg) + U, S, Vᴴ = svd_compact(A, alg′) + return U * S, Vᴴ +end +function right_orth_svd(A, alg, trunc) + alg′ = select_algorithm(svd_compact, A, alg) + alg_trunc = select_algorithm(svd_trunc, A, alg′; trunc) + U, S, Vᴴ = svd_trunc(A, alg_trunc) + return U * S, Vᴴ +end diff --git a/src/factorizations/qr.jl b/src/factorizations/qr.jl index c28da925..03a992bc 100644 --- a/src/factorizations/qr.jl +++ b/src/factorizations/qr.jl @@ -1,4 +1,4 @@ -using MatrixAlgebraKit: MatrixAlgebraKit, qr_compact!, qr_full! +using MatrixAlgebraKit: MatrixAlgebraKit, lq_compact!, lq_full!, qr_compact!, qr_full! # TODO: this is a hardcoded for now to get around this function not being defined in the # type domain @@ -19,6 +19,23 @@ function MatrixAlgebraKit.default_algorithm( return default_blocksparse_qr_algorithm(A; kwargs...) end +function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...) + blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} || + error("unsupported type: $(blocktype(A))") + alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) +end +function MatrixAlgebraKit.default_algorithm( + ::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs... +) + return default_blocksparse_lq_algorithm(A; kwargs...) +end +function MatrixAlgebraKit.default_algorithm( + ::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs... +) + return default_blocksparse_lq_algorithm(A; kwargs...) +end + function similar_output( ::typeof(qr_compact!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm ) diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index ec95d28f..fbaae498 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -265,7 +265,3 @@ function MatrixAlgebraKit.svd_full!( return USVᴴ end - -function MatrixAlgebraKit.left_orth_svd!(A::AbstractBlockSparseMatrix, VC, alg, trunc) - return @invoke MatrixAlgebraKit.left_orth_svd!(A::Any, VC::Any, alg::Any, trunc::Any) -end From d4699b96db36defcbef00f12d86f2791edfc3abf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 27 May 2025 13:03:27 -0400 Subject: [PATCH 3/8] Add polar, add tests --- src/BlockSparseArrays.jl | 1 + src/factorizations/orthnull.jl | 18 +++++----- src/factorizations/polar.jl | 49 +++++++++++++++++++++++++++ test/test_factorizations.jl | 60 ++++++++++++++++++++++++++++++++-- 4 files changed, 117 insertions(+), 11 deletions(-) create mode 100644 src/factorizations/polar.jl diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index 9cd0aa25..3059cc19 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -47,6 +47,7 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl") include("factorizations/svd.jl") include("factorizations/truncation.jl") include("factorizations/qr.jl") +include("factorizations/polar.jl") include("factorizations/orthnull.jl") end diff --git a/src/factorizations/orthnull.jl b/src/factorizations/orthnull.jl index 76389dd9..33665e2f 100644 --- a/src/factorizations/orthnull.jl +++ b/src/factorizations/orthnull.jl @@ -1,5 +1,11 @@ using MatrixAlgebraKit: - MatrixAlgebraKit, left_polar, qr_compact, select_algorithm, svd_compact + MatrixAlgebraKit, + left_polar, + lq_compact, + qr_compact, + right_polar, + select_algorithm, + svd_compact function MatrixAlgebraKit.left_orth( A::AbstractBlockSparseMatrix; @@ -15,9 +21,7 @@ function MatrixAlgebraKit.left_orth( if kind == :qr return left_orth_qr(A, alg_qr) elseif kind == :polar - # TODO: Implement this. - # return left_orth_polar(A, alg_polar) - return left_orth_svd(A, alg_svd) + return left_orth_polar(A, alg_polar) elseif kind == :svd return left_orth_svd(A, alg_svd, trunc) else @@ -39,7 +43,7 @@ function left_orth_svd(A, alg, trunc::Nothing=nothing) end function MatrixAlgebraKit.right_orth( - A; + A::AbstractBlockSparseMatrix; trunc=nothing, kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), @@ -54,9 +58,7 @@ function MatrixAlgebraKit.right_orth( # return right_orth_lq(A, alg_lq) return right_orth_svd(A, alg_svd) elseif kind == :polar - # TODO: Implement this. - # return right_orth_polar(A, alg_polar) - return right_orth_svd(A, alg_svd) + return right_orth_polar(A, alg_polar) elseif kind == :svd return right_orth_svd(A, alg_svd, trunc) else diff --git a/src/factorizations/polar.jl b/src/factorizations/polar.jl new file mode 100644 index 00000000..c364d562 --- /dev/null +++ b/src/factorizations/polar.jl @@ -0,0 +1,49 @@ +using MatrixAlgebraKit: + MatrixAlgebraKit, + PolarViaSVD, + check_input, + default_algorithm, + left_polar!, + right_polar!, + svd_compact! + +function MatrixAlgebraKit.check_input( + ::typeof(left_polar!), A::AbstractBlockSparseMatrix, WP +) + W, P = WP + @views for I in eachblockstoredindex(A) + m, n = size(A[I]) + m >= n || + throw(ArgumentError("each input matrix block needs at least as many rows as columns")) + # check_input(left_polar!, A[I], (W[I1], P[I2])) + end + return nothing +end + +function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, WP, alg::PolarViaSVD) + check_input(left_polar!, A, WP) + U, S, Vᴴ = svd_compact!(A, alg.svdalg) + # TODO: Use more in-place operations here, avoid `copy`. + W = U * Vᴴ + P = copy(Vᴴ') * S * Vᴴ + return (W, P) +end +function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, PWᴴ, alg::PolarViaSVD) + check_input(right_polar!, A, PWᴴ) + U, S, Vᴴ = svd_compact!(A, alg.svdalg) + # TODO: Use more in-place operations here, avoid `copy`. + Wᴴ = U * Vᴴ + P = U * S * copy(U') + return (P, Wᴴ) +end + +function MatrixAlgebraKit.default_algorithm( + ::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs... +) + return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...)) +end +function MatrixAlgebraKit.default_algorithm( + ::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs... +) + return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...)) +end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 02b42f86..f3fc96e8 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,7 +1,17 @@ using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex using MatrixAlgebraKit: - qr_compact, qr_full, svd_compact, svd_full, svd_trunc, truncrank, trunctol + left_orth, + left_polar, + qr_compact, + qr_full, + right_orth, + right_polar, + svd_compact, + svd_full, + svd_trunc, + truncrank, + trunctol using LinearAlgebra: LinearAlgebra using Random: Random using Test: @inferred, @testset, @test @@ -156,7 +166,7 @@ end end end -@testset "qr_compact" for T in (Float32, Float64, ComplexF32, ComplexF64) +@testset "qr_compact (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) for i in [1, 2], j in [1, 2], k in [1, 2], l in [1, 2] A = BlockSparseArray{T}(undef, ([i, j], [k, l])) A[Block(1, 1)] = randn(T, i, k) @@ -167,7 +177,7 @@ end end end -@testset "qr_full" for T in (Float32, Float64, ComplexF32, ComplexF64) +@testset "qr_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3] A = BlockSparseArray{T}(undef, ([i, j], [k, l])) A[Block(1, 1)] = randn(T, i, k) @@ -181,3 +191,47 @@ end @test A ≈ Q * R end end + +@testset "left_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([3, 4], [2, 3])) + A[Block(1, 1)] = randn(T, 3, 2) + A[Block(2, 2)] = randn(T, 4, 3) + + U, C = left_polar(A) + @test U * C ≈ A + @test Matrix(U'U) ≈ LinearAlgebra.I +end + +@testset "right_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [3, 4])) + A[Block(1, 1)] = randn(T, 2, 3) + A[Block(2, 2)] = randn(T, 3, 4) + + C, U = right_polar(A) + @test C * U ≈ A + @test Matrix(U * U') ≈ LinearAlgebra.I +end + +@testset "left_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([3, 4], [2, 3])) + A[Block(1, 1)] = randn(T, 3, 2) + A[Block(2, 2)] = randn(T, 4, 3) + + for kind in (:qr, :polar, :svd) + U, C = left_orth(A; kind) + @test U * C ≈ A + @test Matrix(U'U) ≈ LinearAlgebra.I + end +end + +@testset "right_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = BlockSparseArray{T}(undef, ([2, 3], [3, 4])) + A[Block(1, 1)] = randn(T, 2, 3) + A[Block(2, 2)] = randn(T, 3, 4) + + for kind in (:qr, :polar, :svd) + C, U = right_orth(A; kind) + @test C * U ≈ A + @test Matrix(U * U') ≈ LinearAlgebra.I + end +end From 969b49ec475545740f1c6c9faa37345832a556fa Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 27 May 2025 13:52:56 -0400 Subject: [PATCH 4/8] Improve polar a bit --- src/factorizations/polar.jl | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/factorizations/polar.jl b/src/factorizations/polar.jl index c364d562..9192f011 100644 --- a/src/factorizations/polar.jl +++ b/src/factorizations/polar.jl @@ -7,31 +7,35 @@ using MatrixAlgebraKit: right_polar!, svd_compact! -function MatrixAlgebraKit.check_input( - ::typeof(left_polar!), A::AbstractBlockSparseMatrix, WP -) - W, P = WP +function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::AbstractBlockSparseMatrix) @views for I in eachblockstoredindex(A) m, n = size(A[I]) m >= n || throw(ArgumentError("each input matrix block needs at least as many rows as columns")) - # check_input(left_polar!, A[I], (W[I1], P[I2])) + end + return nothing +end +function MatrixAlgebraKit.check_input(::typeof(right_polar!), A::AbstractBlockSparseMatrix) + @views for I in eachblockstoredindex(A) + m, n = size(A[I]) + m <= n || + throw(ArgumentError("each input matrix block needs at least as many columns as rows")) end return nothing end -function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, WP, alg::PolarViaSVD) - check_input(left_polar!, A, WP) - U, S, Vᴴ = svd_compact!(A, alg.svdalg) +function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD) + check_input(left_polar!, A) # TODO: Use more in-place operations here, avoid `copy`. + U, S, Vᴴ = svd_compact!(A, alg.svdalg) W = U * Vᴴ P = copy(Vᴴ') * S * Vᴴ return (W, P) end -function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, PWᴴ, alg::PolarViaSVD) - check_input(right_polar!, A, PWᴴ) - U, S, Vᴴ = svd_compact!(A, alg.svdalg) +function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD) + check_input(right_polar!, A) # TODO: Use more in-place operations here, avoid `copy`. + U, S, Vᴴ = svd_compact!(A, alg.svdalg) Wᴴ = U * Vᴴ P = U * S * copy(U') return (P, Wᴴ) From ce437875397d6370e185310642ae670a33481837 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 27 May 2025 14:16:56 -0400 Subject: [PATCH 5/8] More in-place --- src/factorizations/orthnull.jl | 80 +++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/src/factorizations/orthnull.jl b/src/factorizations/orthnull.jl index 33665e2f..29cb8449 100644 --- a/src/factorizations/orthnull.jl +++ b/src/factorizations/orthnull.jl @@ -1,13 +1,19 @@ using MatrixAlgebraKit: MatrixAlgebraKit, - left_polar, - lq_compact, - qr_compact, - right_polar, + left_orth_polar!, + left_orth_qr!, + left_orth_svd!, + left_polar!, + lq_compact!, + qr_compact!, + right_orth_lq!, + right_orth_polar!, + right_orth_svd!, + right_polar!, select_algorithm, - svd_compact + svd_compact! -function MatrixAlgebraKit.left_orth( +function MatrixAlgebraKit.left_orth!( A::AbstractBlockSparseMatrix; trunc=nothing, kind=isnothing(trunc) ? :qr : :svd, @@ -19,30 +25,32 @@ function MatrixAlgebraKit.left_orth( throw(ArgumentError("truncation not supported for `left_orth` with `kind=$kind`")) end if kind == :qr - return left_orth_qr(A, alg_qr) + return left_orth_qr!(A, alg_qr) elseif kind == :polar - return left_orth_polar(A, alg_polar) + return left_orth_polar!(A, alg_polar) elseif kind == :svd - return left_orth_svd(A, alg_svd, trunc) + return left_orth_svd!(A, alg_svd, trunc) else throw(ArgumentError("`left_orth` received unknown value `kind = $kind`")) end end -function left_orth_qr(A, alg) - alg′ = select_algorithm(qr_compact, A, alg) - return qr_compact(A, alg′) +function MatrixAlgebraKit.left_orth_qr!(A::AbstractBlockSparseMatrix, alg) + alg′ = select_algorithm(qr_compact!, A, alg) + return qr_compact!(A, alg′) end -function left_orth_polar(A, alg) - alg′ = select_algorithm(left_polar, A, alg) - return left_polar(A, alg′) +function MatrixAlgebraKit.left_orth_polar!(A::AbstractBlockSparseMatrix, alg) + alg′ = select_algorithm(left_polar!, A, alg) + return left_polar!(A, alg′) end -function left_orth_svd(A, alg, trunc::Nothing=nothing) - alg′ = select_algorithm(svd_compact, A, alg) - U, S, Vᴴ = svd_compact(A, alg′) +function MatrixAlgebraKit.left_orth_svd!( + A::AbstractBlockSparseMatrix, alg, trunc::Nothing=nothing +) + alg′ = select_algorithm(svd_compact!, A, alg) + U, S, Vᴴ = svd_compact!(A, alg′) return U, S * Vᴴ end -function MatrixAlgebraKit.right_orth( +function MatrixAlgebraKit.right_orth!( A::AbstractBlockSparseMatrix; trunc=nothing, kind=isnothing(trunc) ? :lq : :svd, @@ -55,32 +63,34 @@ function MatrixAlgebraKit.right_orth( end if kind == :qr # TODO: Implement this. - # return right_orth_lq(A, alg_lq) - return right_orth_svd(A, alg_svd) + # return right_orth_lq!(A, alg_lq) + return right_orth_svd!(A, alg_svd) elseif kind == :polar - return right_orth_polar(A, alg_polar) + return right_orth_polar!(A, alg_polar) elseif kind == :svd - return right_orth_svd(A, alg_svd, trunc) + return right_orth_svd!(A, alg_svd, trunc) else throw(ArgumentError("`right_orth` received unknown value `kind = $kind`")) end end -function right_orth_lq(A, alg) +function MatrixAlgebraKit.right_orth_lq!(A::AbstractBlockSparseMatrix, alg) alg′ = select_algorithm(lq_compact, A, alg) - return lq_compact(A, alg′) + return lq_compact!(A, alg′) end -function right_orth_polar(A, alg) - alg′ = select_algorithm(right_polar, A, alg) - return right_polar(A, alg′) +function MatrixAlgebraKit.right_orth_polar!(A::AbstractBlockSparseMatrix, alg) + alg′ = select_algorithm(right_polar!, A, alg) + return right_polar!(A, alg′) end -function right_orth_svd(A, alg, trunc::Nothing=nothing) - alg′ = select_algorithm(svd_compact, A, alg) - U, S, Vᴴ = svd_compact(A, alg′) +function MatrixAlgebraKit.right_orth_svd!( + A::AbstractBlockSparseMatrix, alg, trunc::Nothing=nothing +) + alg′ = select_algorithm(svd_compact!, A, alg) + U, S, Vᴴ = svd_compact!(A, alg′) return U * S, Vᴴ end -function right_orth_svd(A, alg, trunc) - alg′ = select_algorithm(svd_compact, A, alg) - alg_trunc = select_algorithm(svd_trunc, A, alg′; trunc) - U, S, Vᴴ = svd_trunc(A, alg_trunc) +function MatrixAlgebraKit.right_orth_svd!(A::AbstractBlockSparseMatrix, alg, trunc) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) + U, S, Vᴴ = svd_trunc!(A, alg_trunc) return U * S, Vᴴ end From 825c81673ce6a061165b32b19e776ecd6a6f1fa3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 27 May 2025 20:58:25 -0400 Subject: [PATCH 6/8] Simplify implementation --- src/factorizations/orthnull.jl | 82 ++++++++++++---------------------- src/factorizations/qr.jl | 17 ------- test/test_factorizations.jl | 4 +- 3 files changed, 30 insertions(+), 73 deletions(-) diff --git a/src/factorizations/orthnull.jl b/src/factorizations/orthnull.jl index 29cb8449..bae249ec 100644 --- a/src/factorizations/orthnull.jl +++ b/src/factorizations/orthnull.jl @@ -1,94 +1,68 @@ using MatrixAlgebraKit: MatrixAlgebraKit, - left_orth_polar!, - left_orth_qr!, - left_orth_svd!, + left_orth!, left_polar!, lq_compact!, qr_compact!, - right_orth_lq!, - right_orth_polar!, - right_orth_svd!, + right_orth!, right_polar!, select_algorithm, svd_compact! -function MatrixAlgebraKit.left_orth!( - A::AbstractBlockSparseMatrix; - trunc=nothing, - kind=isnothing(trunc) ? :qr : :svd, - alg_qr=(; positive=true), - alg_polar=(;), - alg_svd=(;), +function MatrixAlgebraKit.initialize_output( + ::typeof(left_orth!), A::AbstractBlockSparseMatrix ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for `left_orth` with `kind=$kind`")) - end - if kind == :qr - return left_orth_qr!(A, alg_qr) - elseif kind == :polar - return left_orth_polar!(A, alg_polar) - elseif kind == :svd - return left_orth_svd!(A, alg_svd, trunc) - else - throw(ArgumentError("`left_orth` received unknown value `kind = $kind`")) - end + return nothing end -function MatrixAlgebraKit.left_orth_qr!(A::AbstractBlockSparseMatrix, alg) +function MatrixAlgebraKit.check_input( + ::typeof(left_orth!), A::AbstractBlockSparseMatrix, F::Nothing +) + return nothing +end + +function MatrixAlgebraKit.left_orth_qr!(A::AbstractBlockSparseMatrix, F, alg) alg′ = select_algorithm(qr_compact!, A, alg) return qr_compact!(A, alg′) end -function MatrixAlgebraKit.left_orth_polar!(A::AbstractBlockSparseMatrix, alg) +function MatrixAlgebraKit.left_orth_polar!(A::AbstractBlockSparseMatrix, F, alg) alg′ = select_algorithm(left_polar!, A, alg) return left_polar!(A, alg′) end function MatrixAlgebraKit.left_orth_svd!( - A::AbstractBlockSparseMatrix, alg, trunc::Nothing=nothing + A::AbstractBlockSparseMatrix, F, alg, trunc::Nothing=nothing ) alg′ = select_algorithm(svd_compact!, A, alg) U, S, Vᴴ = svd_compact!(A, alg′) return U, S * Vᴴ end -function MatrixAlgebraKit.right_orth!( - A::AbstractBlockSparseMatrix; - trunc=nothing, - kind=isnothing(trunc) ? :lq : :svd, - alg_lq=(; positive=true), - alg_polar=(;), - alg_svd=(;), +function MatrixAlgebraKit.initialize_output( + ::typeof(right_orth!), A::AbstractBlockSparseMatrix ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for `right_orth` with `kind=$kind`")) - end - if kind == :qr - # TODO: Implement this. - # return right_orth_lq!(A, alg_lq) - return right_orth_svd!(A, alg_svd) - elseif kind == :polar - return right_orth_polar!(A, alg_polar) - elseif kind == :svd - return right_orth_svd!(A, alg_svd, trunc) - else - throw(ArgumentError("`right_orth` received unknown value `kind = $kind`")) - end + return nothing end -function MatrixAlgebraKit.right_orth_lq!(A::AbstractBlockSparseMatrix, alg) - alg′ = select_algorithm(lq_compact, A, alg) +function MatrixAlgebraKit.check_input( + ::typeof(right_orth!), A::AbstractBlockSparseMatrix, F::Nothing +) + return nothing +end + +function MatrixAlgebraKit.right_orth_lq!(A::AbstractBlockSparseMatrix, F, alg) + alg′ = select_algorithm(lq_compact!, A, alg) return lq_compact!(A, alg′) end -function MatrixAlgebraKit.right_orth_polar!(A::AbstractBlockSparseMatrix, alg) +function MatrixAlgebraKit.right_orth_polar!(A::AbstractBlockSparseMatrix, F, alg) alg′ = select_algorithm(right_polar!, A, alg) return right_polar!(A, alg′) end function MatrixAlgebraKit.right_orth_svd!( - A::AbstractBlockSparseMatrix, alg, trunc::Nothing=nothing + A::AbstractBlockSparseMatrix, F, alg, trunc::Nothing=nothing ) alg′ = select_algorithm(svd_compact!, A, alg) U, S, Vᴴ = svd_compact!(A, alg′) return U * S, Vᴴ end -function MatrixAlgebraKit.right_orth_svd!(A::AbstractBlockSparseMatrix, alg, trunc) +function MatrixAlgebraKit.right_orth_svd!(A::AbstractBlockSparseMatrix, F, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) U, S, Vᴴ = svd_trunc!(A, alg_trunc) diff --git a/src/factorizations/qr.jl b/src/factorizations/qr.jl index 03a992bc..8ec1ccdf 100644 --- a/src/factorizations/qr.jl +++ b/src/factorizations/qr.jl @@ -19,23 +19,6 @@ function MatrixAlgebraKit.default_algorithm( return default_blocksparse_qr_algorithm(A; kwargs...) end -function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...) - blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} || - error("unsupported type: $(blocktype(A))") - alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) -end -function MatrixAlgebraKit.default_algorithm( - ::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs... -) - return default_blocksparse_lq_algorithm(A; kwargs...) -end -function MatrixAlgebraKit.default_algorithm( - ::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs... -) - return default_blocksparse_lq_algorithm(A; kwargs...) -end - function similar_output( ::typeof(qr_compact!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm ) diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 367bb74e..cd5f5015 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -250,7 +250,7 @@ end A[Block(1, 1)] = randn(T, 3, 2) A[Block(2, 2)] = randn(T, 4, 3) - for kind in (:qr, :polar, :svd) + for kind in (:polar, :qr, :svd) U, C = left_orth(A; kind) @test U * C ≈ A @test Matrix(U'U) ≈ LinearAlgebra.I @@ -262,7 +262,7 @@ end A[Block(1, 1)] = randn(T, 2, 3) A[Block(2, 2)] = randn(T, 3, 4) - for kind in (:qr, :polar, :svd) + for kind in (:lq, :polar, :svd) C, U = right_orth(A; kind) @test C * U ≈ A @test Matrix(U * U') ≈ LinearAlgebra.I From bd696cd0706069eddcaee62b083b2f85cfd41138 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Wed, 28 May 2025 07:20:47 -0400 Subject: [PATCH 7/8] Simplify test imports --- test/test_factorizations.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index cd5f5015..ab175900 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,11 +1,6 @@ using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex using MatrixAlgebraKit: - left_orth, - left_polar, - qr_compact, - qr_full, - right_orth, left_orth, left_polar, lq_compact, From be112c8ca65cb60da9d2382e341479b7f027617e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 28 May 2025 07:33:35 -0400 Subject: [PATCH 8/8] Stricter checks, fix left_orth trunc, test trunc --- src/factorizations/orthnull.jl | 60 ++++++++++++++++++++++++++++++++-- test/test_factorizations.jl | 10 ++++++ 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/src/factorizations/orthnull.jl b/src/factorizations/orthnull.jl index bae249ec..b06af8e3 100644 --- a/src/factorizations/orthnull.jl +++ b/src/factorizations/orthnull.jl @@ -14,27 +14,56 @@ function MatrixAlgebraKit.initialize_output( ) return nothing end -function MatrixAlgebraKit.check_input( - ::typeof(left_orth!), A::AbstractBlockSparseMatrix, F::Nothing -) +function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::AbstractBlockSparseMatrix, F) + !isnothing(F) && throw( + ArgumentError( + "`left_orth!` on block sparse matrices does not support specifying the output" + ), + ) return nothing end function MatrixAlgebraKit.left_orth_qr!(A::AbstractBlockSparseMatrix, F, alg) + !isnothing(F) && throw( + ArgumentError( + "`left_orth!` on block sparse matrices does not support specifying the output" + ), + ) alg′ = select_algorithm(qr_compact!, A, alg) return qr_compact!(A, alg′) end function MatrixAlgebraKit.left_orth_polar!(A::AbstractBlockSparseMatrix, F, alg) + !isnothing(F) && throw( + ArgumentError( + "`left_orth!` on block sparse matrices does not support specifying the output" + ), + ) alg′ = select_algorithm(left_polar!, A, alg) return left_polar!(A, alg′) end function MatrixAlgebraKit.left_orth_svd!( A::AbstractBlockSparseMatrix, F, alg, trunc::Nothing=nothing ) + !isnothing(F) && throw( + ArgumentError( + "`left_orth!` on block sparse matrices does not support specifying the output" + ), + ) alg′ = select_algorithm(svd_compact!, A, alg) U, S, Vᴴ = svd_compact!(A, alg′) return U, S * Vᴴ end +function MatrixAlgebraKit.left_orth_svd!(A::AbstractBlockSparseMatrix, F, alg, trunc) + !isnothing(F) && throw( + ArgumentError( + "`left_orth!` on block sparse matrices does not support specifying the output" + ), + ) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) + U, S, Vᴴ = svd_trunc!(A, alg_trunc) + return U, S * Vᴴ +end function MatrixAlgebraKit.initialize_output( ::typeof(right_orth!), A::AbstractBlockSparseMatrix @@ -44,25 +73,50 @@ end function MatrixAlgebraKit.check_input( ::typeof(right_orth!), A::AbstractBlockSparseMatrix, F::Nothing ) + !isnothing(F) && throw( + ArgumentError( + "`right_orth!` on block sparse matrices does not support specifying the output" + ), + ) return nothing end function MatrixAlgebraKit.right_orth_lq!(A::AbstractBlockSparseMatrix, F, alg) + !isnothing(F) && throw( + ArgumentError( + "`right_orth!` on block sparse matrices does not support specifying the output" + ), + ) alg′ = select_algorithm(lq_compact!, A, alg) return lq_compact!(A, alg′) end function MatrixAlgebraKit.right_orth_polar!(A::AbstractBlockSparseMatrix, F, alg) + !isnothing(F) && throw( + ArgumentError( + "`right_orth!` on block sparse matrices does not support specifying the output" + ), + ) alg′ = select_algorithm(right_polar!, A, alg) return right_polar!(A, alg′) end function MatrixAlgebraKit.right_orth_svd!( A::AbstractBlockSparseMatrix, F, alg, trunc::Nothing=nothing ) + !isnothing(F) && throw( + ArgumentError( + "`right_orth!` on block sparse matrices does not support specifying the output" + ), + ) alg′ = select_algorithm(svd_compact!, A, alg) U, S, Vᴴ = svd_compact!(A, alg′) return U * S, Vᴴ end function MatrixAlgebraKit.right_orth_svd!(A::AbstractBlockSparseMatrix, F, alg, trunc) + !isnothing(F) && throw( + ArgumentError( + "`right_orth!` on block sparse matrices does not support specifying the output" + ), + ) alg′ = select_algorithm(svd_compact!, A, alg) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) U, S, Vᴴ = svd_trunc!(A, alg_trunc) diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index ab175900..2ae4ebce 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -250,6 +250,11 @@ end @test U * C ≈ A @test Matrix(U'U) ≈ LinearAlgebra.I end + + U, C = left_orth(A; trunc=(; maxrank=2)) + @test size(U, 2) ≤ 2 + @test size(C, 1) ≤ 2 + @test Matrix(U'U) ≈ LinearAlgebra.I end @testset "right_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -262,4 +267,9 @@ end @test C * U ≈ A @test Matrix(U * U') ≈ LinearAlgebra.I end + + C, U = right_orth(A; trunc=(; maxrank=2)) + @test size(C, 2) ≤ 2 + @test size(U, 1) ≤ 2 + @test Matrix(U * U') ≈ LinearAlgebra.I end