Skip to content

Commit 57f3321

Browse files
authored
[BlockSparseArrays] Redesign block views again (#1513)
* [BlockSparseArrays] Redesign block views again * [NDTensors] Bump to v0.3.38
1 parent 806397a commit 57f3321

File tree

4 files changed

+65
-7
lines changed

4 files changed

+65
-7
lines changed

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,33 @@ function blocked_cartesianindices(axes::Tuple, subaxes::Tuple, blocks)
396396
end
397397
end
398398

399+
# Represents a view of a block of a blocked array.
400+
struct BlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
401+
array::Array
402+
block::Tuple{Vararg{Block{1,Int},N}}
403+
end
404+
function Base.axes(a::BlockView)
405+
# TODO: Try to avoid conversion to `Base.OneTo{Int}`, or just convert
406+
# the element type to `Int` with `Int.(...)`.
407+
# When the axes of `a.array` are `GradedOneTo`, the block is `LabelledUnitRange`,
408+
# which has element type `LabelledInteger`. That causes conversion problems
409+
# in some generic Base Julia code, for example when printing `BlockView`.
410+
return ntuple(ndims(a)) do dim
411+
return Base.OneTo{Int}(only(axes(axes(a.array, dim)[a.block[dim]])))
412+
end
413+
end
414+
function Base.size(a::BlockView)
415+
return length.(axes(a))
416+
end
417+
function Base.getindex(a::BlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
418+
return blocks(a.array)[Int.(a.block)...][index...]
419+
end
420+
function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) where {N}
421+
blocks(a.array)[Int.(a.block)...] = blocks(a.array)[Int.(a.block)...]
422+
blocks(a.array)[Int.(a.block)...][index...] = value
423+
return a
424+
end
425+
399426
function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N}
400427
return view!(a, Tuple(index)...)
401428
end

src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,18 @@ function Base.setindex!(
4242
blocksparse_setindex!(a, value, I...)
4343
return a
4444
end
45+
46+
function Base.setindex!(
47+
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
48+
) where {N}
49+
blocksize = ntuple(dim -> length(axes(a, dim)[I[dim]]), N)
50+
if size(value) blocksize
51+
throw(
52+
DimensionMismatch(
53+
"Trying to set block $(Block(Int.(I)...)), which has a size $blocksize, with data of size $(size(value)).",
54+
),
55+
)
56+
end
57+
blocks(a)[Int.(I)...] = value
58+
return a
59+
end

src/abstractblocksparsearray/views.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using BlockArrays: Block, BlockSlices
1+
using BlockArrays: BlockArrays, Block, BlockSlices, viewblock
22

33
function blocksparse_view(a, I...)
44
return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...)
@@ -22,3 +22,19 @@ function Base.view(
2222
)
2323
return blocksparse_view(a, I)
2424
end
25+
26+
# Specialized code for getting the view of a block.
27+
function BlockArrays.viewblock(
28+
a::AbstractBlockSparseArray{<:Any,N}, block::Block{N}
29+
) where {N}
30+
return viewblock(a, Tuple(block)...)
31+
end
32+
function BlockArrays.viewblock(
33+
a::AbstractBlockSparseArray{<:Any,N}, block::Vararg{Block{1},N}
34+
) where {N}
35+
I = CartesianIndex(Int.(block))
36+
if I stored_indices(blocks(a))
37+
return blocks(a)[I]
38+
end
39+
return BlockView(a, block)
40+
end

test/test_basics.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using BlockArrays:
1717
using Compat: @compat
1818
using LinearAlgebra: mul!
1919
using NDTensors.BlockSparseArrays:
20-
@view!, BlockSparseArray, block_nstored, block_reshape, view!
20+
@view!, BlockSparseArray, BlockView, block_nstored, block_reshape, view!
2121
using NDTensors.SparseArrayInterface: nstored
2222
using NDTensors.TensorAlgebra: contract
2323
using Test: @test, @test_broken, @test_throws, @testset
@@ -362,10 +362,10 @@ include("TestBlockSparseArraysUtils.jl")
362362
b = @view a[Block(2, 2)]
363363
@test size(b) == (3, 4)
364364
for i in parentindices(b)
365-
@test i isa BlockSlice{<:Block{1}}
365+
@test i isa Base.OneTo{Int}
366366
end
367-
@test parentindices(b)[1] == BlockSlice(Block(2), 3:5)
368-
@test parentindices(b)[2] == BlockSlice(Block(2), 4:7)
367+
@test parentindices(b)[1] == 1:3
368+
@test parentindices(b)[2] == 1:4
369369

370370
a = BlockSparseArray{elt}([2, 3], [3, 4])
371371
b = @view a[Block(2, 2)[1:2, 2:2]]
@@ -392,9 +392,9 @@ include("TestBlockSparseArraysUtils.jl")
392392

393393
a = BlockSparseArray{elt}([2, 3], [3, 4])
394394
b = @views a[Block(2, 2)][1:2, 2:3]
395-
@test b isa SubArray{<:Any,<:Any,<:BlockSparseArray}
395+
@test b isa SubArray{<:Any,<:Any,<:BlockView}
396396
for i in parentindices(b)
397-
@test i isa BlockSlice{<:BlockIndexRange{1}}
397+
@test i isa UnitRange{Int}
398398
end
399399
x = randn(elt, 2, 2)
400400
b .= x

0 commit comments

Comments
 (0)