Skip to content

Implement left_orth/right_orth #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.6.2"
version = "0.6.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 2 additions & 0 deletions src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,7 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
include("factorizations/svd.jl")
include("factorizations/truncation.jl")
include("factorizations/qr.jl")
include("factorizations/polar.jl")
include("factorizations/orthnull.jl")

end
96 changes: 96 additions & 0 deletions src/factorizations/orthnull.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
using MatrixAlgebraKit:
MatrixAlgebraKit,
left_orth_polar!,
left_orth_qr!,
left_orth_svd!,
left_polar!,
lq_compact!,
qr_compact!,
right_orth_lq!,
right_orth_polar!,
right_orth_svd!,
right_polar!,
select_algorithm,
svd_compact!

function MatrixAlgebraKit.left_orth!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not caught by the MatrixAlgebraKit implementation? If not, it probably should?

Copy link
Member Author

@mtfishman mtfishman May 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit subtle, the MatrixAlgebraKit.jl version is designed around the output being pre-allocated, so it has an extra argument for the output. That is tricky to support for block sparse matrices since the different backends (SVD, QR/LQ, polar) might have different block structures/sectors so it isn't straightforward to pre-allocate the output in that way.

We can move this code over to MatrixAlgebraKit.jl, that was the main thing I wanted your feedback on.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see now! This was indeed something that I at some point saw in MatrixAlgebraKit but then kind of forgot about in the end.
My initial reaction would be that I do think this should indeed be moved over there, such that the general codepath follows that of the @algdef-defined functions, with a little extra implementations that make it hard to simply use @algdef.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, in that case I'll make a PR to MatrixAlgebraKit.jl, it would be nice to not have this logic here.

A::AbstractBlockSparseMatrix;
trunc=nothing,
kind=isnothing(trunc) ? :qr : :svd,
alg_qr=(; positive=true),
alg_polar=(;),
alg_svd=(;),
)
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for `left_orth` with `kind=$kind`"))
end
if kind == :qr
return left_orth_qr!(A, alg_qr)
elseif kind == :polar
return left_orth_polar!(A, alg_polar)
elseif kind == :svd
return left_orth_svd!(A, alg_svd, trunc)
else
throw(ArgumentError("`left_orth` received unknown value `kind = $kind`"))
end
end
function MatrixAlgebraKit.left_orth_qr!(A::AbstractBlockSparseMatrix, alg)
alg′ = select_algorithm(qr_compact!, A, alg)
return qr_compact!(A, alg′)
end
function MatrixAlgebraKit.left_orth_polar!(A::AbstractBlockSparseMatrix, alg)
alg′ = select_algorithm(left_polar!, A, alg)
return left_polar!(A, alg′)
end
function MatrixAlgebraKit.left_orth_svd!(
A::AbstractBlockSparseMatrix, alg, trunc::Nothing=nothing
)
alg′ = select_algorithm(svd_compact!, A, alg)
U, S, Vᴴ = svd_compact!(A, alg′)
return U, S * Vᴴ
end

function MatrixAlgebraKit.right_orth!(
A::AbstractBlockSparseMatrix;
trunc=nothing,
kind=isnothing(trunc) ? :lq : :svd,
alg_lq=(; positive=true),
alg_polar=(;),
alg_svd=(;),
)
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for `right_orth` with `kind=$kind`"))
end
if kind == :qr
# TODO: Implement this.
# return right_orth_lq!(A, alg_lq)
return right_orth_svd!(A, alg_svd)
elseif kind == :polar
return right_orth_polar!(A, alg_polar)
elseif kind == :svd
return right_orth_svd!(A, alg_svd, trunc)
else
throw(ArgumentError("`right_orth` received unknown value `kind = $kind`"))
end
end
function MatrixAlgebraKit.right_orth_lq!(A::AbstractBlockSparseMatrix, alg)
alg′ = select_algorithm(lq_compact, A, alg)
return lq_compact!(A, alg′)
end
function MatrixAlgebraKit.right_orth_polar!(A::AbstractBlockSparseMatrix, alg)
alg′ = select_algorithm(right_polar!, A, alg)
return right_polar!(A, alg′)
end
function MatrixAlgebraKit.right_orth_svd!(
A::AbstractBlockSparseMatrix, alg, trunc::Nothing=nothing
)
alg′ = select_algorithm(svd_compact!, A, alg)
U, S, Vᴴ = svd_compact!(A, alg′)
return U * S, Vᴴ
end
function MatrixAlgebraKit.right_orth_svd!(A::AbstractBlockSparseMatrix, alg, trunc)
alg′ = select_algorithm(svd_compact!, A, alg)
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
U, S, Vᴴ = svd_trunc!(A, alg_trunc)
return U * S, Vᴴ
end
53 changes: 53 additions & 0 deletions src/factorizations/polar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using MatrixAlgebraKit:
MatrixAlgebraKit,
PolarViaSVD,
check_input,
default_algorithm,
left_polar!,
right_polar!,
svd_compact!

function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::AbstractBlockSparseMatrix)
@views for I in eachblockstoredindex(A)
m, n = size(A[I])
m >= n ||
throw(ArgumentError("each input matrix block needs at least as many rows as columns"))
end
return nothing
end
function MatrixAlgebraKit.check_input(::typeof(right_polar!), A::AbstractBlockSparseMatrix)
@views for I in eachblockstoredindex(A)
m, n = size(A[I])
m <= n ||
throw(ArgumentError("each input matrix block needs at least as many columns as rows"))
end
return nothing
end

function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD)
check_input(left_polar!, A)
# TODO: Use more in-place operations here, avoid `copy`.
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
W = U * Vᴴ
P = copy(Vᴴ') * S * Vᴴ
return (W, P)
end
function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD)
check_input(right_polar!, A)
# TODO: Use more in-place operations here, avoid `copy`.
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
Wᴴ = U * Vᴴ
P = U * S * copy(U')
return (P, Wᴴ)
end

function MatrixAlgebraKit.default_algorithm(
::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs...
)
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
end
function MatrixAlgebraKit.default_algorithm(
::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs...
)
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
end
19 changes: 18 additions & 1 deletion src/factorizations/qr.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using MatrixAlgebraKit: MatrixAlgebraKit, qr_compact!, qr_full!
using MatrixAlgebraKit: MatrixAlgebraKit, lq_compact!, lq_full!, qr_compact!, qr_full!

# TODO: this is a hardcoded for now to get around this function not being defined in the
# type domain
Expand All @@ -19,6 +19,23 @@ function MatrixAlgebraKit.default_algorithm(
return default_blocksparse_qr_algorithm(A; kwargs...)
end

function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...)
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
error("unsupported type: $(blocktype(A))")
alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...)
return BlockPermutedDiagonalAlgorithm(alg)
end
function MatrixAlgebraKit.default_algorithm(
::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs...
)
return default_blocksparse_lq_algorithm(A; kwargs...)
end
function MatrixAlgebraKit.default_algorithm(
::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs...
)
return default_blocksparse_lq_algorithm(A; kwargs...)
end

function similar_output(
::typeof(qr_compact!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
)
Expand Down
60 changes: 57 additions & 3 deletions test/test_factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
using MatrixAlgebraKit:
qr_compact, qr_full, svd_compact, svd_full, svd_trunc, truncrank, trunctol
left_orth,
left_polar,
qr_compact,
qr_full,
right_orth,
right_polar,
svd_compact,
svd_full,
svd_trunc,
truncrank,
trunctol
using LinearAlgebra: LinearAlgebra
using Random: Random
using Test: @inferred, @testset, @test
Expand Down Expand Up @@ -156,7 +166,7 @@ end
end
end

@testset "qr_compact" for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "qr_compact (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
for i in [1, 2], j in [1, 2], k in [1, 2], l in [1, 2]
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
A[Block(1, 1)] = randn(T, i, k)
Expand All @@ -167,7 +177,7 @@ end
end
end

@testset "qr_full" for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "qr_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
A[Block(1, 1)] = randn(T, i, k)
Expand All @@ -181,3 +191,47 @@ end
@test A ≈ Q * R
end
end

@testset "left_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
A[Block(1, 1)] = randn(T, 3, 2)
A[Block(2, 2)] = randn(T, 4, 3)

U, C = left_polar(A)
@test U * C ≈ A
@test Matrix(U'U) ≈ LinearAlgebra.I
end

@testset "right_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
A[Block(1, 1)] = randn(T, 2, 3)
A[Block(2, 2)] = randn(T, 3, 4)

C, U = right_polar(A)
@test C * U ≈ A
@test Matrix(U * U') ≈ LinearAlgebra.I
end

@testset "left_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
A[Block(1, 1)] = randn(T, 3, 2)
A[Block(2, 2)] = randn(T, 4, 3)

for kind in (:qr, :polar, :svd)
U, C = left_orth(A; kind)
@test U * C ≈ A
@test Matrix(U'U) ≈ LinearAlgebra.I
end
end

@testset "right_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
A[Block(1, 1)] = randn(T, 2, 3)
A[Block(2, 2)] = randn(T, 3, 4)

for kind in (:qr, :polar, :svd)
C, U = right_orth(A; kind)
@test C * U ≈ A
@test Matrix(U * U') ≈ LinearAlgebra.I
end
end
Loading