Skip to content

Commit 1995934

Browse files
committed
Fix more slicing operations
1 parent 5b5d8d4 commit 1995934

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,9 @@ end
158158
const BlockSliceCollection = Union{
159159
Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}}
160160
}
161-
const SubBlockSliceCollection = BlockIndices{
162-
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
161+
const SubBlockSliceCollection = Union{
162+
BlockSlice{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}},
163+
BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}},
163164
}
164165

165166
# TODO: This is type piracy. This is used in `reindex` when making

src/abstractblocksparsearray/views.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,14 @@ end
9292
# TODO: Move to `GradedUnitRanges` or `BlockArraysExtensions`.
9393
to_block(I::Block{1}) = I
9494
to_block(I::BlockIndexRange{1}) = Block(I)
95+
to_block(I::BlockIndexVector) = Block(I)
9596
to_block_indices(I::Block{1}) = Colon()
9697
to_block_indices(I::BlockIndexRange{1}) = only(I.indices)
98+
to_block_indices(I::BlockIndexVector) = I.indices
9799

98100
function Base.view(
99-
a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Union{Block{1},BlockIndexRange{1}},N}
101+
a::AbstractBlockSparseArray{<:Any,N},
102+
I::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
100103
) where {N}
101104
return @views a[to_block.(I)...][to_block_indices.(I)...]
102105
end
@@ -108,7 +111,7 @@ function Base.view(
108111
end
109112
function Base.view(
110113
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N}},
111-
I::Vararg{Union{Block{1},BlockIndexRange{1}},N},
114+
I::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
112115
) where {T,N}
113116
return @views a[to_block.(I)...][to_block_indices.(I)...]
114117
end
@@ -205,8 +208,14 @@ function BlockArrays.viewblock(
205208
end
206209

207210
function to_blockindexrange(
208-
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}},
209-
I::Block{1},
211+
a::BlockSlice{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}}, I::Block{1}
212+
)
213+
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
214+
# work right now.
215+
return blocks(a.block)[Int(I)]
216+
end
217+
function to_blockindexrange(
218+
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}}, I::Block{1}
210219
)
211220
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
212221
# work right now.
@@ -247,20 +256,29 @@ end
247256

248257
# Block slice of the result of slicing `@view a[2:5, 2:5]`.
249258
# TODO: Move this to `BlockArraysExtensions`.
250-
const BlockedSlice = Union{
251-
BlockSlice{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}},
252-
BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}},
259+
const BlockIndexRangeSlice = BlockSlice{
260+
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
261+
}
262+
const BlockIndexVectorSlice = BlockSlice{
263+
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}
253264
}
265+
const BlockedSlice = Union{BlockIndexRangeSlice,BlockIndexVectorSlice}
254266

255267
function Base.view(
256268
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
257269
block::Union{Block{N},BlockIndexRange{N}},
258270
) where {T,N}
259271
return viewblock(a, block)
260272
end
273+
function Base.view(
274+
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockIndexRangeSlice,N}}},
275+
block::Union{Block{N},BlockIndexRange{N}},
276+
) where {T,N}
277+
return viewblock(a, block)
278+
end
261279
function Base.view(
262280
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
263-
block::Vararg{Union{Block{1},BlockIndexRange{1}},N},
281+
block::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
264282
) where {T,N}
265283
return viewblock(a, block...)
266284
end

0 commit comments

Comments
 (0)