Skip to content

Commit a5e0c3f

Browse files
committed
Try fixing tests
1 parent f0e3ea2 commit a5e0c3f

File tree

5 files changed

+80
-49
lines changed

5 files changed

+80
-49
lines changed

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ function Base.to_indices(
4545
)
4646
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
4747
end
48+
# Fix ambiguity error with Base for logical indexing in Julia 1.10.
49+
# TODO: Delete this once we drop support for Julia 1.10.
50+
function Base.to_indices(
51+
a::AnyAbstractBlockSparseArray, inds, I::Union{Tuple{BitArray{N}},Tuple{Array{Bool,N}}}
52+
) where {N}
53+
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
54+
end
4855

4956
# a[[Block(2), Block(1)], [Block(2), Block(1)]]
5057
function Base.to_indices(

src/factorizations/lq.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
using MatrixAlgebraKit: MatrixAlgebraKit, default_lq_algorithm, lq_compact!, lq_full!
22

3+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
34
function MatrixAlgebraKit.default_lq_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
45
return default_lq_algorithm(typeof(A); kwargs...)
56
end
6-
function MatrixAlgebraKit.default_lq_algorithm(
7-
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
8-
)
9-
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
10-
error("unsupported type: $(blocktype(A))")
11-
alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...)
12-
return BlockPermutedDiagonalAlgorithm(alg)
13-
end
7+
8+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
149
function MatrixAlgebraKit.default_algorithm(
1510
::typeof(lq_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
1611
)
1712
return default_lq_algorithm(A; kwargs...)
1813
end
14+
15+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
1916
function MatrixAlgebraKit.default_algorithm(
2017
::typeof(lq_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
2118
)
22-
return default_q_algorithm(A; kwargs...)
19+
return default_lq_algorithm(A; kwargs...)
20+
end
21+
22+
function MatrixAlgebraKit.default_lq_algorithm(
23+
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
24+
)
25+
alg = default_lq_algorithm(blocktype(A); kwargs...)
26+
return BlockPermutedDiagonalAlgorithm(alg)
2327
end
2428

2529
function similar_output(

src/factorizations/polar.jl

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,34 @@ using MatrixAlgebraKit:
77
right_polar!,
88
svd_compact!
99

10+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
11+
function MatrixAlgebraKit.default_algorithm(
12+
::typeof(left_polar!), A::AbstractBlockSparseMatrix; kwargs...
13+
)
14+
return default_algorithm(left_polar!, typeof(A); kwargs...)
15+
end
16+
17+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
18+
function MatrixAlgebraKit.default_algorithm(
19+
::typeof(left_polar!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
20+
)
21+
return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...))
22+
end
23+
24+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
25+
function MatrixAlgebraKit.default_algorithm(
26+
::typeof(right_polar!), A::AbstractBlockSparseMatrix; kwargs...
27+
)
28+
return default_algorithm(right_polar!, typeof(A); kwargs...)
29+
end
30+
31+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
32+
function MatrixAlgebraKit.default_algorithm(
33+
::typeof(right_polar!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
34+
)
35+
return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...))
36+
end
37+
1038
function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::AbstractBlockSparseMatrix)
1139
@views for I in eachblockstoredindex(A)
1240
m, n = size(A[I])
@@ -46,24 +74,3 @@ function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarV
4674
P = U * S * copy(U')
4775
return (P, Wᴴ)
4876
end
49-
50-
function MatrixAlgebraKit.default_algorithm(
51-
::typeof(left_polar!), A::AbstractBlockSparseMatrix; kwargs...
52-
)
53-
return default_algorithm(left_polar!, typeof(A); kwargs...)
54-
end
55-
function MatrixAlgebraKit.default_algorithm(
56-
::typeof(left_polar!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
57-
)
58-
return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...))
59-
end
60-
function MatrixAlgebraKit.default_algorithm(
61-
::typeof(right_polar!), A::AbstractBlockSparseMatrix; kwargs...
62-
)
63-
return default_algorithm(right_polar!, typeof(A); kwargs...)
64-
end
65-
function MatrixAlgebraKit.default_algorithm(
66-
::typeof(right_polar!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
67-
)
68-
return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...))
69-
end

src/factorizations/qr.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,32 @@
11
using MatrixAlgebraKit:
22
MatrixAlgebraKit, default_qr_algorithm, lq_compact!, lq_full!, qr_compact!, qr_full!
33

4+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
45
function MatrixAlgebraKit.default_qr_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
56
return default_qr_algorithm(typeof(A); kwargs...)
67
end
7-
function MatrixAlgebraKit.default_qr_algorithm(
8-
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
9-
)
10-
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
11-
error("unsupported type: $(blocktype(A))")
12-
alg = MatrixAlgebraKit.LAPACK_HouseholderQR(; kwargs...)
13-
return BlockPermutedDiagonalAlgorithm(alg)
14-
end
8+
9+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
1510
function MatrixAlgebraKit.default_algorithm(
1611
::typeof(qr_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
1712
)
1813
return default_qr_algorithm(A; kwargs...)
1914
end
15+
16+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
2017
function MatrixAlgebraKit.default_algorithm(
2118
::typeof(qr_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
2219
)
2320
return default_qr_algorithm(A; kwargs...)
2421
end
2522

23+
function MatrixAlgebraKit.default_qr_algorithm(
24+
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
25+
)
26+
alg = default_qr_algorithm(blocktype(A); kwargs...)
27+
return BlockPermutedDiagonalAlgorithm(alg)
28+
end
29+
2630
function similar_output(
2731
::typeof(qr_compact!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
2832
)

src/factorizations/svd.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
using MatrixAlgebraKit: MatrixAlgebraKit, default_svd_algorithm, svd_compact!, svd_full!
22

3+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
4+
using MatrixAlgebraKit: TruncatedAlgorithm, select_truncation, svd_trunc!
5+
function MatrixAlgebraKit.select_algorithm(
6+
::typeof(svd_trunc!), A::Type{<:AbstractBlockSparseMatrix}, alg; trunc=nothing, kwargs...
7+
)
8+
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
9+
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
10+
end
11+
312
"""
413
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)
514
@@ -12,32 +21,32 @@ struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
1221
alg::A
1322
end
1423

24+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
1525
function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
1626
return default_svd_algorithm(typeof(A), kwargs...)
1727
end
18-
function MatrixAlgebraKit.default_svd_algorithm(
19-
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
20-
)
21-
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
22-
error("unsupported type: $(blocktype(A))")
23-
# TODO: this is a hardcoded for now to get around this function not being defined in the
24-
# type domain
25-
# alg = MatrixAlgebraKit.default_algorithm(f, blocktype(A); kwargs...)
26-
alg = MatrixAlgebraKit.LAPACK_DivideAndConquer(; kwargs...)
27-
return BlockPermutedDiagonalAlgorithm(alg)
28-
end
2928

29+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
3030
function MatrixAlgebraKit.default_algorithm(
3131
f::typeof(svd_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
3232
)
3333
return default_svd_algorithm(A; kwargs...)
3434
end
35+
36+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
3537
function MatrixAlgebraKit.default_algorithm(
3638
f::typeof(svd_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
3739
)
3840
return default_svd_algorithm(A; kwargs...)
3941
end
4042

43+
function MatrixAlgebraKit.default_svd_algorithm(
44+
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
45+
)
46+
alg = default_svd_algorithm(blocktype(A); kwargs...)
47+
return BlockPermutedDiagonalAlgorithm(alg)
48+
end
49+
4150
function similar_output(
4251
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
4352
)

0 commit comments

Comments
 (0)