Skip to content

Commit 37fecf7

Browse files
authored
Fix broken SVD tests (#33)
* reenable broken tests * use `isstored` instead of `I in eachstoredindex` * Bump version and minimal SparseArraysBase compat * Formatter
1 parent 60ec997 commit 37fecf7

File tree

3 files changed

+16
-28
lines changed

3 files changed

+16
-28
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.10"
4+
version = "0.2.11"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -42,7 +42,7 @@ LabelledNumbers = "0.1.0"
4242
LinearAlgebra = "1.10"
4343
MacroTools = "0.5.13"
4444
MapBroadcast = "0.1.5"
45-
SparseArraysBase = "0.2.2"
45+
SparseArraysBase = "0.2.10"
4646
SplitApplyCombine = "1.2.3"
4747
TensorAlgebra = "0.1.0"
4848
Test = "1.10"

src/abstractblocksparsearray/views.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,8 @@ function BlockArrays.viewblock(
6868
a::AbstractBlockSparseArray{<:Any,N}, block::Vararg{Block{1},N}
6969
) where {N}
7070
I = CartesianIndex(Int.(block))
71-
# TODO: Use `eachblockstoredindex`.
72-
if I eachstoredindex(blocks(a))
73-
return blocks(a)[I]
74-
end
75-
return BlockView(a, block)
71+
# TODO: isblockstored
72+
return isstored(blocks(a), I) ? blocks(a)[I] : BlockView(a, block)
7673
end
7774

7875
# Specialized code for getting the view of a subblock.

test/test_svd.jl

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ using LinearAlgebra: LinearAlgebra
55
using Random: Random
66
using Test: @inferred, @testset, @test
77

8-
function test_svd(a, usv; broken=false)
8+
function test_svd(a, usv)
99
U, S, V = usv
10-
@test U * diagonal(S) * V' a broken = broken
11-
@test U' * U LinearAlgebra.I
12-
@test V' * V LinearAlgebra.I
10+
return (U * diagonal(S) * V' a) &&
11+
(U' * U LinearAlgebra.I) &&
12+
(V' * V LinearAlgebra.I)
1313
end
1414

1515
# regular matrix
@@ -19,7 +19,7 @@ eltypes = (Float32, Float64, ComplexF64)
1919
@testset "($m, $n) Matrix{$T}" for ((m, n), T) in Iterators.product(sizes, eltypes)
2020
a = rand(m, n)
2121
usv = @inferred svd(a)
22-
test_svd(a, usv)
22+
@test test_svd(a, usv)
2323
end
2424

2525
# block matrix
@@ -28,7 +28,7 @@ blockszs = (([2, 2], [2, 2]), ([2, 2], [2, 3]), ([2, 2, 1], [2, 3]), ([2, 3], [2
2828
@testset "($m, $n) BlockMatrix{$T}" for ((m, n), T) in Iterators.product(blockszs, eltypes)
2929
a = mortar([rand(T, i, j) for i in m, j in n])
3030
usv = svd(a)
31-
test_svd(a, usv)
31+
@test test_svd(a, usv)
3232
@test usv.U isa BlockedMatrix
3333
@test usv.Vt isa BlockedMatrix
3434
@test usv.S isa BlockedVector
@@ -39,17 +39,8 @@ end
3939
@testset "($m, $n) BlockDiagonal{$T}" for ((m, n), T) in
4040
Iterators.product(blockszs, eltypes)
4141
a = BlockDiagonal([rand(T, i, j) for (i, j) in zip(m, n)])
42-
if VERSION v"1.11"
43-
usv = svd(a)
44-
# TODO: `BlockDiagonal * Adjoint` errors
45-
# TODO: This is broken because of https://github.com/JuliaLang/julia/issues/57034,
46-
# fix and reenable.
47-
test_svd(a, usv; broken=true)
48-
else
49-
# `svd(a)` depends on `diagind(::AbstractMatrix, ::IndexStyle)`
50-
# being defined, but it was only introduced in Julia v1.11.
51-
@test svd(a) broken = true
52-
end
42+
usv = svd(a)
43+
@test test_svd(a, usv)
5344
end
5445

5546
# blocksparse
@@ -60,25 +51,25 @@ end
6051

6152
# test empty matrix
6253
usv_empty = svd(a)
63-
test_svd(a, usv_empty)
54+
@test test_svd(a, usv_empty)
6455

6556
# test blockdiagonal
6657
for i in LinearAlgebra.diagind(blocks(a))
6758
I = CartesianIndices(blocks(a))[i]
6859
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
6960
end
7061
usv = svd(a)
71-
test_svd(a, usv)
62+
@test test_svd(a, usv)
7263

7364
perm = Random.randperm(length(m))
7465
b = a[Block.(perm), Block.(1:length(n))]
7566
usv = svd(b)
76-
test_svd(b, usv)
67+
@test test_svd(b, usv)
7768

7869
# test permuted blockdiagonal with missing row/col
7970
I_removed = rand(eachblockstoredindex(b))
8071
c = copy(b)
8172
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
8273
usv = svd(c)
83-
test_svd(c, usv)
74+
@test test_svd(c, usv)
8475
end

0 commit comments

Comments
 (0)