@@ -3,11 +3,14 @@ using BlockArrays:
3
3
AbstractBlockArray,
4
4
AbstractBlockVector,
5
5
Block,
6
+ BlockIndex,
7
+ BlockIndexRange,
6
8
BlockRange,
9
+ BlockSlice,
10
+ BlockVector,
7
11
BlockedOneTo,
8
12
BlockedUnitRange,
9
- BlockVector,
10
- BlockSlice,
13
+ BlockedVector,
11
14
block,
12
15
blockaxes,
13
16
blockedrange,
@@ -17,8 +20,30 @@ using BlockArrays:
17
20
findblockindex
18
21
using Compat: allequal
19
22
using Dictionaries: Dictionary, Indices
20
- using .. GradedAxes: blockedunitrange_getindices
21
- using .. SparseArrayInterface: stored_indices
23
+ using .. GradedAxes: blockedunitrange_getindices, to_blockindices
24
+ using .. SparseArrayInterface: SparseArrayInterface, nstored, stored_indices
25
+
26
+ # A return type for `blocks(array)` when `array` isn't blocked.
27
+ # Represents a vector with just that single block.
28
+ struct SingleBlockView{T,N,Array<: AbstractArray{T,N} } <: AbstractArray{T,N}
29
+ array:: Array
30
+ end
31
+ blocks_maybe_single (a) = blocks (a)
32
+ blocks_maybe_single (a:: Array ) = SingleBlockView (a)
33
+ function Base. getindex (a:: SingleBlockView{<:Any,N} , index:: Vararg{Int,N} ) where {N}
34
+ @assert all (isone, index)
35
+ return a. array
36
+ end
37
+
38
+ # A wrapper around a potentially blocked array that is not blocked.
39
+ struct NonBlockedArray{T,N,Array<: AbstractArray{T,N} } <: AbstractArray{T,N}
40
+ array:: Array
41
+ end
42
+ Base. size (a:: NonBlockedArray ) = size (a. array)
43
+ Base. getindex (a:: NonBlockedArray{<:Any,N} , I:: Vararg{Integer,N} ) where {N} = a. array[I... ]
44
+ BlockArrays. blocks (a:: NonBlockedArray ) = SingleBlockView (a. array)
45
+ const NonBlockedVector{T,Array} = NonBlockedArray{T,1 ,Array}
46
+ NonBlockedVector (array:: AbstractVector ) = NonBlockedArray (array)
22
47
23
48
# BlockIndices works around an issue that the indices of BlockSlice
24
49
# are restricted to AbstractUnitRange{Int}.
@@ -37,6 +62,43 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
37
62
@assert length (S. indices[Block (i)]) == length (i. indices)
38
63
return BlockSlice (S. blocks[Int (Block (i))], S. indices[Block (i)])
39
64
end
65
+
66
+ # This is used in slicing like:
67
+ # a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
68
+ # I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
69
+ # a[I, I]
70
+ function Base. getindex (
71
+ S:: BlockIndices{<:AbstractBlockVector{<:Block{1}}} , i:: BlockSlice{<:Block{1}}
72
+ )
73
+ # TODO : Check for conistency of indices.
74
+ # Wrapping the indices in `NonBlockedVector` reinterprets the blocked indices
75
+ # as a single block, since the result shouldn't be blocked.
76
+ return NonBlockedVector (BlockIndices (S. blocks[Block (i)], S. indices[Block (i)]))
77
+ end
78
+ function Base. getindex (
79
+ S:: BlockIndices{<:BlockedVector{<:Block{1},<:BlockRange{1}}} , i:: BlockSlice{<:Block{1}}
80
+ )
81
+ return i
82
+ end
83
+
84
+ # Used in indexing such as:
85
+ # ```julia
86
+ # a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
87
+ # I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
88
+ # b = @view a[I, I]
89
+ # @view b[Block(1, 1)[1:2, 2:2]]
90
+ # ```
91
+ # This is similar to the definition:
92
+ # blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})
93
+ function Base. getindex (
94
+ a:: NonBlockedVector{<:Integer,<:BlockIndices} , I:: UnitRange{<:Integer}
95
+ )
96
+ ax = only (axes (a. array. indices))
97
+ brs = to_blockindices (ax, I)
98
+ inds = blockedunitrange_getindices (ax, I)
99
+ return NonBlockedVector (a. array[BlockSlice (brs, inds)])
100
+ end
101
+
40
102
function Base. getindex (S:: BlockIndices , i:: BlockSlice{<:BlockRange{1}} )
41
103
# TODO : Check that `i.indices` is consistent with `S.indices`.
42
104
# TODO : Turn this into a `blockedunitrange_getindices` definition.
@@ -50,6 +112,34 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
50
112
return BlockIndices (subblocks, subindices)
51
113
end
52
114
115
+ # Used when performing slices like:
116
+ # @views a[[Block(2), Block(1)]][2:4, 2:4]
117
+ function Base. getindex (S:: BlockIndices , i:: BlockSlice{<:BlockVector{<:BlockIndex{1}}} )
118
+ subblocks = mortar (
119
+ map (blocks (i. block)) do br
120
+ return S. blocks[Int (Block (br))][only (br. indices)]
121
+ end ,
122
+ )
123
+ subindices = mortar (
124
+ map (blocks (i. block)) do br
125
+ S. indices[br]
126
+ end ,
127
+ )
128
+ return BlockIndices (subblocks, subindices)
129
+ end
130
+
131
+ # Similar to the definition of `BlockArrays.BlockSlices`:
132
+ # ```julia
133
+ # const BlockSlices = Union{Base.Slice,BlockSlice{<:BlockRange{1}}}
134
+ # ```
135
+ # but includes `BlockIndices`, where the blocks aren't contiguous.
136
+ const BlockSliceCollection = Union{
137
+ Base. Slice,BlockSlice{<: BlockRange{1} },BlockIndices{<: Vector{<:Block{1}} }
138
+ }
139
+ const SubBlockSliceCollection = BlockIndices{
140
+ <: BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
141
+ }
142
+
53
143
# TODO : This is type piracy. This is used in `reindex` when making
54
144
# views of blocks of sliced block arrays, for example:
55
145
# ```julia
@@ -218,6 +308,12 @@ function blockrange(axis::AbstractUnitRange, r::UnitRange)
218
308
return findblock (axis, first (r)): findblock (axis, last (r))
219
309
end
220
310
311
+ # Occurs when slicing with `a[2:4, 2:4]`.
312
+ function blockrange (axis:: BlockedOneTo{<:Integer} , r:: BlockedUnitRange{<:Integer} )
313
+ # TODO : Check the blocks are commensurate.
314
+ return findblock (axis, first (r)): findblock (axis, last (r))
315
+ end
316
+
221
317
function blockrange (axis:: AbstractUnitRange , r:: Int )
222
318
# # return findblock(axis, r)
223
319
return error (" Slicing with integer values isn't supported." )
@@ -241,14 +337,17 @@ function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
241
337
return only (blockaxes (r))
242
338
end
243
339
244
- # This handles changing the blocking, for example :
340
+ # This handles block merging :
245
341
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
342
+ # I = BlockedVector(Block.(1:4), [2, 2])
343
+ # I = BlockVector(Block.(1:4), [2, 2])
246
344
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
345
+ # I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
247
346
# a[I, I]
248
- # TODO : Generalize to `AbstractBlockedUnitRange` and ` AbstractBlockVector`.
249
- function blockrange (axis :: BlockedOneTo{<:Integer} , r :: BlockVector{<:Integer} )
250
- # TODO : Probably this is incorrect and should be something like:
251
- # return findblock(axis, first(r)):findblock(axis, last(r))
347
+ function blockrange (axis :: BlockedOneTo{<:Integer} , r :: AbstractBlockVector{<:Block{1}} )
348
+ for b in r
349
+ @assert b ∈ blockaxes (axis, 1 )
350
+ end
252
351
return only (blockaxes (r))
253
352
end
254
353
@@ -287,6 +386,10 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice)
287
386
return only (blockaxes (axis))
288
387
end
289
388
389
+ function blockrange (axis:: AbstractUnitRange , r:: NonBlockedVector )
390
+ return Block (1 ): Block (1 )
391
+ end
392
+
290
393
function blockrange (axis:: AbstractUnitRange , r)
291
394
return error (" Slicing not implemented for range of type `$(typeof (r)) `." )
292
395
end
@@ -423,7 +526,18 @@ function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) wher
423
526
return a
424
527
end
425
528
426
- function view! (a:: BlockSparseArray{<:Any,N} , index:: Block{N} ) where {N}
529
+ function SparseArrayInterface. nstored (a:: BlockView )
530
+ # TODO : Store whether or not the block is stored already as
531
+ # a Bool in `BlockView`.
532
+ I = CartesianIndex (Int .(a. block))
533
+ # TODO : Use `block_stored_indices`.
534
+ if I ∈ stored_indices (blocks (a. array))
535
+ return nstored (blocks (a. array)[I])
536
+ end
537
+ return 0
538
+ end
539
+
540
+ function view! (a:: AbstractArray{<:Any,N} , index:: Block{N} ) where {N}
427
541
return view! (a, Tuple (index)... )
428
542
end
429
543
function view! (a:: AbstractArray{<:Any,N} , index:: Vararg{Block{1},N} ) where {N}
0 commit comments