Skip to content

Commit f0e3ea2

Browse files
committed
Fix factorizations
1 parent 77c269c commit f0e3ea2

File tree

4 files changed

+37
-18
lines changed

4 files changed

+37
-18
lines changed

src/factorizations/lq.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
1-
using MatrixAlgebraKit: MatrixAlgebraKit, lq_compact!, lq_full!
1+
using MatrixAlgebraKit: MatrixAlgebraKit, default_lq_algorithm, lq_compact!, lq_full!
22

3-
# TODO: this is a hardcoded for now to get around this function not being defined in the
4-
# type domain
5-
function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...)
3+
function MatrixAlgebraKit.default_lq_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
4+
return default_lq_algorithm(typeof(A); kwargs...)
5+
end
6+
function MatrixAlgebraKit.default_lq_algorithm(
7+
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
8+
)
69
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
710
error("unsupported type: $(blocktype(A))")
811
alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...)
912
return BlockPermutedDiagonalAlgorithm(alg)
1013
end
1114
function MatrixAlgebraKit.default_algorithm(
12-
::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs...
15+
::typeof(lq_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
1316
)
14-
return default_blocksparse_lq_algorithm(A; kwargs...)
17+
return default_lq_algorithm(A; kwargs...)
1518
end
1619
function MatrixAlgebraKit.default_algorithm(
17-
::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs...
20+
::typeof(lq_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
1821
)
19-
return default_blocksparse_lq_algorithm(A; kwargs...)
22+
return default_q_algorithm(A; kwargs...)
2023
end
2124

2225
function similar_output(

src/factorizations/polar.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,22 @@ function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarV
4848
end
4949

5050
function MatrixAlgebraKit.default_algorithm(
51-
::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs...
51+
::typeof(left_polar!), A::AbstractBlockSparseMatrix; kwargs...
5252
)
53-
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
53+
return default_algorithm(left_polar!, typeof(A); kwargs...)
5454
end
5555
function MatrixAlgebraKit.default_algorithm(
56-
::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs...
56+
::typeof(left_polar!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
5757
)
58-
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
58+
return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...))
59+
end
60+
function MatrixAlgebraKit.default_algorithm(
61+
::typeof(right_polar!), A::AbstractBlockSparseMatrix; kwargs...
62+
)
63+
return default_algorithm(right_polar!, typeof(A); kwargs...)
64+
end
65+
function MatrixAlgebraKit.default_algorithm(
66+
::typeof(right_polar!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
67+
)
68+
return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...))
5969
end

src/factorizations/qr.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
using MatrixAlgebraKit:
22
MatrixAlgebraKit, default_qr_algorithm, lq_compact!, lq_full!, qr_compact!, qr_full!
33

4-
# TODO: this is a hardcoded for now to get around this function not being defined in the
5-
# type domain
64
function MatrixAlgebraKit.default_qr_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
5+
return default_qr_algorithm(typeof(A); kwargs...)
6+
end
7+
function MatrixAlgebraKit.default_qr_algorithm(
8+
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
9+
)
710
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
811
error("unsupported type: $(blocktype(A))")
912
alg = MatrixAlgebraKit.LAPACK_HouseholderQR(; kwargs...)
1013
return BlockPermutedDiagonalAlgorithm(alg)
1114
end
1215
function MatrixAlgebraKit.default_algorithm(
13-
::typeof(qr_compact!), A::AbstractBlockSparseMatrix; kwargs...
16+
::typeof(qr_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
1417
)
1518
return default_qr_algorithm(A; kwargs...)
1619
end
1720
function MatrixAlgebraKit.default_algorithm(
18-
::typeof(qr_full!), A::AbstractBlockSparseMatrix; kwargs...
21+
::typeof(qr_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
1922
)
2023
return default_qr_algorithm(A; kwargs...)
2124
end

src/factorizations/svd.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
1212
alg::A
1313
end
1414

15+
function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
16+
return default_svd_algorithm(typeof(A), kwargs...)
17+
end
1518
function MatrixAlgebraKit.default_svd_algorithm(
1619
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
1720
)
@@ -25,12 +28,12 @@ function MatrixAlgebraKit.default_svd_algorithm(
2528
end
2629

2730
function MatrixAlgebraKit.default_algorithm(
28-
f::typeof(svd_compact!), A::AbstractBlockSparseMatrix; kwargs...
31+
f::typeof(svd_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
2932
)
3033
return default_svd_algorithm(A; kwargs...)
3134
end
3235
function MatrixAlgebraKit.default_algorithm(
33-
f::typeof(svd_full!), A::AbstractBlockSparseMatrix; kwargs...
36+
f::typeof(svd_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
3437
)
3538
return default_svd_algorithm(A; kwargs...)
3639
end

0 commit comments

Comments
 (0)