Skip to content

Commit 806397a

Browse files
authored
[BlockSparseArrays] Towards block merging (#1512)
* [BlockSparseArrays] Towards block merging * [NDTensors] Bump to v0.3.37
1 parent 0ef89ca commit 806397a

File tree

4 files changed

+51
-2
lines changed

4 files changed

+51
-2
lines changed

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,26 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
230230
return r
231231
end
232232

233+
# This handles changing the blocking, for example:
234+
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
235+
# I = blockedrange([4, 4])
236+
# a[I, I]
237+
# TODO: Generalize to `AbstractBlockedUnitRange`.
238+
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
239+
# TODO: Probably this is incorrect and should be something like:
240+
# return findblock(axis, first(r)):findblock(axis, last(r))
241+
return only(blockaxes(r))
242+
end
243+
244+
# This handles changing the blocking, for example:
245+
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
246+
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
247+
# a[I, I]
248+
# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`.
233249
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
234-
return error("Slicing not implemented for range of type `$(typeof(r))`.")
250+
# TODO: Probably this is incorrect and should be something like:
251+
# return findblock(axis, first(r)):findblock(axis, last(r))
252+
return only(blockaxes(r))
235253
end
236254

237255
using BlockArrays: BlockSlice

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Adapt: Adapt, WrappedArray
22
using BlockArrays:
33
BlockArrays,
4+
AbstractBlockVector,
45
AbstractBlockedUnitRange,
56
BlockIndexRange,
67
BlockRange,
@@ -40,8 +41,9 @@ function Base.to_indices(
4041
end
4142

4243
# a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])]
44+
# a[BlockedVector([Block(2), Block(1)], [2]), BlockedVector([Block(2), Block(1)], [2])]
4345
function Base.to_indices(
44-
a::BlockSparseArrayLike, inds, I::Tuple{BlockVector{<:Block{1}},Vararg{Any}}
46+
a::BlockSparseArrayLike, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}}
4547
)
4648
return blocksparse_to_indices(a, inds, I)
4749
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using BlockArrays:
44
BlockIndex,
55
BlockVector,
66
BlockedUnitRange,
7+
BlockedVector,
78
block,
89
blockcheckbounds,
910
blocklengths,
@@ -46,6 +47,12 @@ function blocksparse_to_indices(a, inds, I::Tuple{BlockVector{<:Block{1}},Vararg
4647
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
4748
end
4849

50+
# TODO: Should this be combined with the version above?
51+
function blocksparse_to_indices(a, inds, I::Tuple{BlockedVector{<:Block{1}},Vararg{Any}})
52+
I1 = blockedunitrange_getindices(inds[1], I[1])
53+
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
54+
end
55+
4956
# TODO: Need to implement this!
5057
function block_merge end
5158

@@ -223,6 +230,9 @@ function Base.size(a::SparseSubArrayBlocks)
223230
return length.(axes(a))
224231
end
225232
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
233+
# TODO: Should this be defined as `@view a.array[Block(I)]` instead?
234+
## return @view a.array[Block(I)]
235+
226236
parent_blocks = @view blocks(parent(a.array))[blockrange(a)...]
227237
parent_block = parent_blocks[I...]
228238
# TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.

test/test_basics.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using BlockArrays:
77
BlockVector,
88
BlockedOneTo,
99
BlockedUnitRange,
10+
BlockedVector,
1011
blockedrange,
1112
blocklength,
1213
blocklengths,
@@ -23,6 +24,24 @@ using Test: @test, @test_broken, @test_throws, @testset
2324
include("TestBlockSparseArraysUtils.jl")
2425
@testset "BlockSparseArrays (eltype=$elt)" for elt in
2526
(Float32, Float64, ComplexF32, ComplexF64)
27+
@testset "Broken" begin
28+
a = BlockSparseArray{elt}([2, 2, 2, 2], [2, 2, 2, 2])
29+
@views for I in [Block(1, 1), Block(2, 2), Block(3, 3), Block(4, 4)]
30+
a[I] = randn(elt, size(a[I]))
31+
end
32+
33+
I = blockedrange([4, 4])
34+
b = @view a[I, I]
35+
@test_broken copy(b)
36+
37+
I = BlockedVector(Block.(1:4), [2, 2])
38+
b = @view a[I, I]
39+
@test_broken copy(b)
40+
41+
I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
42+
b = @view a[I, I]
43+
@test_broken copy(b)
44+
end
2645
@testset "Basics" begin
2746
a = BlockSparseArray{elt}([2, 3], [2, 3])
2847
@test a == BlockSparseArray{elt}(blockedrange([2, 3]), blockedrange([2, 3]))

0 commit comments

Comments
 (0)