Skip to content

Commit a2cf2f5

Browse files
committed
Add tests for zero dimensional BlockSparseArray
1 parent 60626ae commit a2cf2f5

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,11 @@ function Base.similar(
220220
return blocksparse_similar(a, elt, axes)
221221
end
222222

223+
# Fixes ambiguity error.
224+
function Base.similar(a::BlockSparseArrayLike{<:Any,0}, elt::Type, axes::Tuple{})
225+
return blocksparse_similar(a, elt, axes)
226+
end
227+
223228
# Fixes ambiguity error with `BlockArrays`.
224229
function Base.similar(
225230
a::BlockSparseArrayLike,

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,37 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
109109

110110
a[3, 3] = NaN
111111
@test isnan(norm(a))
112+
113+
# Empty constructor
114+
for a in (dev(BlockSparseArray{elt}()), dev(BlockSparseArray{elt}(undef)))
115+
@test size(a) == ()
116+
@test isone(length(a))
117+
@test blocksize(a) == ()
118+
@test blocksizes(a) == fill(())
119+
@test iszero(block_nstored(a))
120+
@test iszero(@allowscalar(a[]))
121+
@test iszero(@allowscalar(a[CartesianIndex()]))
122+
@test a[Block()] == dev(fill(0))
123+
@test iszero(@allowscalar(a[Block()][]))
124+
# Broken:
125+
## @test b[Block()[]] == 2
126+
for b in (
127+
(b = copy(a); @allowscalar b[] = 2; b),
128+
(b = copy(a); @allowscalar b[CartesianIndex()] = 2; b),
129+
)
130+
@test size(b) == ()
131+
@test isone(length(b))
132+
@test blocksize(b) == ()
133+
@test blocksizes(b) == fill(())
134+
@test isone(block_nstored(b))
135+
@test @allowscalar(b[]) == 2
136+
@test @allowscalar(b[CartesianIndex()]) == 2
137+
@test b[Block()] == dev(fill(2))
138+
@test @allowscalar(b[Block()][]) == 2
139+
# Broken:
140+
## @test b[Block()[]] == 2
141+
end
142+
end
112143
end
113144
@testset "Tensor algebra" begin
114145
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))

0 commit comments

Comments
 (0)