Skip to content

Commit 1b918eb

Browse files
committed
Better implementation of left_orth/right_orth
1 parent 784c469 commit 1b918eb

File tree

4 files changed

+103
-5
lines changed

4 files changed

+103
-5
lines changed

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
4747
include("factorizations/svd.jl")
4848
include("factorizations/truncation.jl")
4949
include("factorizations/qr.jl")
50+
include("factorizations/orthnull.jl")
5051

5152
end

src/factorizations/orthnull.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
using MatrixAlgebraKit:
2+
MatrixAlgebraKit, left_polar, qr_compact, select_algorithm, svd_compact
3+
4+
function MatrixAlgebraKit.left_orth(
5+
A::AbstractBlockSparseMatrix;
6+
trunc=nothing,
7+
kind=isnothing(trunc) ? :qr : :svd,
8+
alg_qr=(; positive=true),
9+
alg_polar=(;),
10+
alg_svd=(;),
11+
)
12+
if !isnothing(trunc) && kind != :svd
13+
throw(ArgumentError("truncation not supported for `left_orth` with `kind=$kind`"))
14+
end
15+
if kind == :qr
16+
return left_orth_qr(A, alg_qr)
17+
elseif kind == :polar
18+
# TODO: Implement this.
19+
# return left_orth_polar(A, alg_polar)
20+
return left_orth_svd(A, alg_svd)
21+
elseif kind == :svd
22+
return left_orth_svd(A, alg_svd, trunc)
23+
else
24+
throw(ArgumentError("`left_orth` received unknown value `kind = $kind`"))
25+
end
26+
end
27+
function left_orth_qr(A, alg)
28+
alg′ = select_algorithm(qr_compact, A, alg)
29+
return qr_compact(A, alg′)
30+
end
31+
function left_orth_polar(A, alg)
32+
alg′ = select_algorithm(left_polar, A, alg)
33+
return left_polar(A, alg′)
34+
end
35+
function left_orth_svd(A, alg, trunc::Nothing=nothing)
36+
alg′ = select_algorithm(svd_compact, A, alg)
37+
U, S, Vᴴ = svd_compact(A, alg′)
38+
return U, S * Vᴴ
39+
end
40+
41+
function MatrixAlgebraKit.right_orth(
42+
A;
43+
trunc=nothing,
44+
kind=isnothing(trunc) ? :lq : :svd,
45+
alg_lq=(; positive=true),
46+
alg_polar=(;),
47+
alg_svd=(;),
48+
)
49+
if !isnothing(trunc) && kind != :svd
50+
throw(ArgumentError("truncation not supported for `right_orth` with `kind=$kind`"))
51+
end
52+
if kind == :qr
53+
# TODO: Implement this.
54+
# return right_orth_lq(A, alg_lq)
55+
return right_orth_svd(A, alg_svd)
56+
elseif kind == :polar
57+
# TODO: Implement this.
58+
# return right_orth_polar(A, alg_polar)
59+
return right_orth_svd(A, alg_svd)
60+
elseif kind == :svd
61+
return right_orth_svd(A, alg_svd, trunc)
62+
else
63+
throw(ArgumentError("`right_orth` received unknown value `kind = $kind`"))
64+
end
65+
end
66+
function right_orth_lq(A, alg)
67+
alg′ = select_algorithm(lq_compact, A, alg)
68+
return lq_compact(A, alg′)
69+
end
70+
function right_orth_polar(A, alg)
71+
alg′ = select_algorithm(right_polar, A, alg)
72+
return right_polar(A, alg′)
73+
end
74+
function right_orth_svd(A, alg, trunc::Nothing=nothing)
75+
alg′ = select_algorithm(svd_compact, A, alg)
76+
U, S, Vᴴ = svd_compact(A, alg′)
77+
return U * S, Vᴴ
78+
end
79+
function right_orth_svd(A, alg, trunc)
80+
alg′ = select_algorithm(svd_compact, A, alg)
81+
alg_trunc = select_algorithm(svd_trunc, A, alg′; trunc)
82+
U, S, Vᴴ = svd_trunc(A, alg_trunc)
83+
return U * S, Vᴴ
84+
end

src/factorizations/qr.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using MatrixAlgebraKit: MatrixAlgebraKit, qr_compact!, qr_full!
1+
using MatrixAlgebraKit: MatrixAlgebraKit, lq_compact!, lq_full!, qr_compact!, qr_full!
22

33
# TODO: this is a hardcoded for now to get around this function not being defined in the
44
# type domain
@@ -19,6 +19,23 @@ function MatrixAlgebraKit.default_algorithm(
1919
return default_blocksparse_qr_algorithm(A; kwargs...)
2020
end
2121

22+
function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...)
23+
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
24+
error("unsupported type: $(blocktype(A))")
25+
alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...)
26+
return BlockPermutedDiagonalAlgorithm(alg)
27+
end
28+
function MatrixAlgebraKit.default_algorithm(
29+
::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs...
30+
)
31+
return default_blocksparse_lq_algorithm(A; kwargs...)
32+
end
33+
function MatrixAlgebraKit.default_algorithm(
34+
::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs...
35+
)
36+
return default_blocksparse_lq_algorithm(A; kwargs...)
37+
end
38+
2239
function similar_output(
2340
::typeof(qr_compact!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
2441
)

src/factorizations/svd.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,3 @@ function MatrixAlgebraKit.svd_full!(
265265

266266
return USVᴴ
267267
end
268-
269-
function MatrixAlgebraKit.left_orth_svd!(A::AbstractBlockSparseMatrix, VC, alg, trunc)
270-
return @invoke MatrixAlgebraKit.left_orth_svd!(A::Any, VC::Any, alg::Any, trunc::Any)
271-
end

0 commit comments

Comments
 (0)