Skip to content

Commit 34ba9ef

Browse files
authored
[BlockSparseArrays] Fix initializing blocks using broadcasting (#1506)
* [BlockSparseArrays] Fix initializing blocks using broadcasting * [NDTensors] Bump to v0.3.34
1 parent 6244905 commit 34ba9ef

File tree

4 files changed

+84
-167
lines changed

4 files changed

+84
-167
lines changed

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,22 @@ function Base.setindex!(a::BlockSparseArrayLike{<:Any,1}, value, I::Block{1})
185185
return a
186186
end
187187

188+
function Base.fill!(a::AbstractBlockSparseArray, value)
189+
if iszero(value)
190+
# This drops all of the blocks.
191+
sparse_zero!(blocks(a))
192+
return a
193+
end
194+
blocksparse_fill!(a, value)
195+
return a
196+
end
197+
188198
function Base.fill!(a::BlockSparseArrayLike, value)
199+
# TODO: Even if `iszero(value)`, this doesn't drop
200+
# blocks from `a`, and additionally allocates
201+
# new blocks filled with zeros, unlike
202+
# `fill!(a::AbstractBlockSparseArray, value)`.
203+
# Consider changing that behavior when possible.
189204
blocksparse_fill!(a, value)
190205
return a
191206
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,6 @@ function blocksparse_setindex!(
9696
end
9797

9898
function blocksparse_fill!(a::AbstractArray, value)
99-
if iszero(value)
100-
# This drops all of the blocks.
101-
sparse_zero!(blocks(a))
102-
return a
103-
end
10499
for b in BlockRange(a)
105100
# We can't use:
106101
# ```julia
@@ -284,12 +279,14 @@ function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) w
284279
return a[Tuple(I)...]
285280
end
286281
function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N}
287-
parent_blocks = view(blocks(parent(a.array)), axes(a)...)
282+
parent_blocks = @view blocks(parent(a.array))[blockrange(a)...]
288283
# TODO: The following line is required to instantiate
289284
# uninstantiated blocks, maybe use `@view!` instead,
290285
# or some other code pattern.
291286
parent_blocks[I...] = parent_blocks[I...]
292-
return parent_blocks[I...][blockindices(parent(a.array), Block(I), a.array.indices)...] =
287+
# TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
288+
block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a)))
289+
return parent_blocks[I...][blockindices(parent(a.array), block, a.array.indices)...] =
293290
value
294291
end
295292
function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}

test/backup/runtests.jl

Lines changed: 0 additions & 147 deletions
This file was deleted.

test/test_basics.jl

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using BlockArrays:
1313
blocksize,
1414
blocksizes,
1515
mortar
16+
using Compat: @compat
1617
using LinearAlgebra: mul!
1718
using NDTensors.BlockSparseArrays:
1819
@view!, BlockSparseArray, block_nstored, block_reshape, view!
@@ -23,22 +24,21 @@ include("TestBlockSparseArraysUtils.jl")
2324
@testset "BlockSparseArrays (eltype=$elt)" for elt in
2425
(Float32, Float64, ComplexF32, ComplexF64)
2526
@testset "Broken" begin
26-
a = BlockSparseArray{elt}([2, 3], [3, 4])
27-
@test_broken a[Block(1, 2)] .= 1
28-
2927
a = BlockSparseArray{elt}([2, 3], [3, 4])
3028
b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]]
29+
@test b isa SubArray{<:Any,<:Any,<:BlockSparseArray}
3130
@test_broken b[2:4, 2:4]
3231

32+
a = BlockSparseArray{elt}([2, 3], [3, 4])
33+
b = @views a[[Block(2), Block(1)], [Block(2), Block(1)]][Block(1, 1)]
34+
@test_broken b isa SubArray{<:Any,<:Any,<:BlockSparseArray}
35+
3336
a = BlockSparseArray{elt}([2, 3], [3, 4])
3437
b = @views a[Block(1, 1)][1:2, 1:1]
38+
@test b isa SubArray{<:Any,<:Any,<:BlockSparseArray}
3539
for i in parentindices(b)
3640
@test_broken i isa BlockSlice{<:BlockIndexRange{1}}
3741
end
38-
39-
a = BlockSparseArray{elt}([2, 3], [3, 4])
40-
b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]]
41-
@test_broken b[Block(1, 1)] = randn(3, 3)
4242
end
4343
@testset "Basics" begin
4444
a = BlockSparseArray{elt}([2, 3], [2, 3])
@@ -82,6 +82,26 @@ include("TestBlockSparseArraysUtils.jl")
8282
@test block_nstored(a) == 2
8383
@test nstored(a) == 2 * 4 + 3 * 3
8484

85+
a = BlockSparseArray{elt}([2, 3], [3, 4])
86+
a[Block(1, 2)] .= 2
87+
@test eltype(a) == elt
88+
@test all(==(2), a[Block(1, 2)])
89+
@test iszero(a[Block(1, 1)])
90+
@test iszero(a[Block(2, 1)])
91+
@test iszero(a[Block(2, 2)])
92+
@test block_nstored(a) == 1
93+
@test nstored(a) == 2 * 4
94+
95+
a = BlockSparseArray{elt}([2, 3], [3, 4])
96+
a[Block(1, 2)] .= 0
97+
@test eltype(a) == elt
98+
@test iszero(a[Block(1, 1)])
99+
@test iszero(a[Block(2, 1)])
100+
@test iszero(a[Block(1, 2)])
101+
@test iszero(a[Block(2, 2)])
102+
@test block_nstored(a) == 1
103+
@test nstored(a) == 2 * 4
104+
85105
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
86106
@views for b in [Block(1, 2), Block(2, 1)]
87107
a[b] = randn(elt, size(a[b]))
@@ -490,13 +510,45 @@ include("TestBlockSparseArraysUtils.jl")
490510
@test b[Block(2, 2)] == x
491511
end
492512

513+
function f1()
514+
a = BlockSparseArray{elt}([2, 3], [3, 4])
515+
b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]]
516+
x = randn(elt, 3, 4)
517+
b[Block(1, 1)] .= x
518+
return (; a, b, x)
519+
end
520+
function f2()
521+
a = BlockSparseArray{elt}([2, 3], [3, 4])
522+
b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]]
523+
x = randn(elt, 3, 4)
524+
b[Block(1, 1)] = x
525+
return (; a, b, x)
526+
end
527+
for abx in (f1(), f2())
528+
@compat (; a, b, x) = abx
529+
@test b isa SubArray{<:Any,<:Any,<:BlockSparseArray}
530+
@test block_nstored(b) == 1
531+
@test b[Block(1, 1)] == x
532+
for blck in [Block(2, 1), Block(1, 2), Block(2, 2)]
533+
@test iszero(b[blck])
534+
end
535+
@test block_nstored(a) == 1
536+
@test a[Block(2, 2)] == x
537+
for blck in [Block(1, 1), Block(2, 1), Block(1, 2)]
538+
@test iszero(a[blck])
539+
end
540+
@test_throws DimensionMismatch b[Block(1, 1)] .= randn(2, 3)
541+
end
542+
493543
a = BlockSparseArray{elt}([2, 3], [3, 4])
494-
b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]]
495-
x = randn(elt, 3, 4)
496-
b[Block(1, 1)] .= x
497-
@test b[Block(1, 1)] == x
498-
@test a[Block(2, 2)] == x
499-
@test_throws DimensionMismatch b[Block(1, 1)] .= randn(2, 3)
544+
b = @views a[[Block(2), Block(1)], [Block(2), Block(1)]][Block(2, 1)]
545+
@test iszero(b)
546+
@test size(b) == (2, 4)
547+
x = randn(elt, 2, 4)
548+
b .= x
549+
@test b == x
550+
@test a[Block(1, 2)] == x
551+
@test block_nstored(a) == 1
500552

501553
a = BlockSparseArray{elt}([2, 3], [3, 4])
502554
b = @view a[[Block(2), Block(1)], [Block(2), Block(1)]]

0 commit comments

Comments
 (0)