Skip to content

Commit f523144

Browse files
committed
Fix tests
1 parent d8a546d commit f523144

File tree

4 files changed

+21
-15
lines changed

4 files changed

+21
-15
lines changed

src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value)
5757
return a
5858
end
5959

60+
# Catch zero-dimensional case to avoid scalar indexing.
61+
function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value, ::Block{0})
62+
blocks(a)[] = value
63+
return a
64+
end
65+
6066
function Base.setindex!(
6167
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
6268
) where {N}

src/abstractblocksparsearray/map.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ArrayLayouts: LayoutArray
22
using BlockArrays: blockisequal
33
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
4+
using GPUArraysCore: @allowscalar
45
using LinearAlgebra: Adjoint, Transpose
56
using SparseArraysBase: SparseArraysBase, SparseArrayStyle
67

@@ -55,7 +56,7 @@ function map_zero_dim! end
5556
@interface ::AbstractArrayInterface function map_zero_dim!(
5657
f, a_dest::AbstractArray, a_srcs::AbstractArray...
5758
)
58-
a_dest[] = f.(map(a_src -> a_src[], a_srcs)...)
59+
@allowscalar a_dest[] = f.(map(a_src -> a_src[], a_srcs)...)
5960
return a_dest
6061
end
6162

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ using BlockArrays:
1515
blocks,
1616
findblockindex
1717
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface
18-
using GPUArraysCore: @allowscalar
1918
using LinearAlgebra: Adjoint, Transpose
2019
using SparseArraysBase:
2120
AbstractSparseArrayInterface,
@@ -66,7 +65,7 @@ end
6665
@interface ::AbstractBlockSparseArrayInterface function Base.getindex(
6766
a::AbstractArray{<:Any,0}
6867
)
69-
return @allowscalar a[Block()[]]
68+
return a[Block()[]]
7069
end
7170

7271
# a[1:2, 1:2]
@@ -157,7 +156,7 @@ end
157156
)
158157
a_b = blocks(a)[]
159158
# `value[]` handles scalars and 0-dimensional arrays.
160-
@allowscalar a_b[] = value[]
159+
a_b[] = value[]
161160
# Set the block, required if it is structurally zero.
162161
blocks(a)[] = a_b
163162
return a

test/test_basics.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -183,17 +183,17 @@ arrayts = (Array, JLArray)
183183
@test blocksize(a) == ()
184184
@test blocksizes(a) == fill(())
185185
@test iszero(blockstoredlength(a))
186-
@test iszero(a[])
187-
@test iszero([CartesianIndex()])
186+
@test iszero(@allowscalar(a[]))
187+
@test iszero(@allowscalar(a[CartesianIndex()]))
188188
@test a[Block()] == dev(fill(0))
189-
@test iszero(a[Block()][])
190-
@test iszero(a[Block()[]])
189+
@test iszero(@allowscalar(a[Block()][]))
190+
@test iszero(@allowscalar(a[Block()[]]))
191191
@test Array(a) isa Array{elt,0}
192192
@test Array(a) == fill(0)
193193
for b in (
194-
(b = copy(a); b[] = 2; b),
195-
(b = copy(a); b[CartesianIndex()] = 2; b),
196-
(b = copy(a); b[CartesianIndex()] = 2; b),
194+
(b = copy(a); @allowscalar(b[] = 2); b),
195+
(b = copy(a); @allowscalar(b[CartesianIndex()] = 2); b),
196+
(b = copy(a); @allowscalar(b[Block()[]] = 2); b),
197197
# Regression test for https://github.com/ITensor/BlockSparseArrays.jl/issues/27.
198198
(b = copy(a); b[Block()] = dev(fill(2)); b),
199199
)
@@ -202,11 +202,11 @@ arrayts = (Array, JLArray)
202202
@test blocksize(b) == ()
203203
@test blocksizes(b) == fill(())
204204
@test isone(blockstoredlength(b))
205-
@test b[] == 2
206-
@test b[CartesianIndex()] == 2
205+
@test @allowscalar(b[]) == 2
206+
@test @allowscalar(b[CartesianIndex()]) == 2
207207
@test b[Block()] == dev(fill(2))
208-
@test b[Block()][] == 2
209-
@test b[Block()[]] == 2
208+
@test @allowscalar(b[Block()][]) == 2
209+
@test @allowscalar(b[Block()[]]) == 2
210210
@test Array(b) isa Array{elt,0}
211211
@test Array(b) == fill(2)
212212
end

0 commit comments

Comments
 (0)