Skip to content

Commit 42413dd

Browse files
committed
Fix some broken slicing operations
1 parent 443ac70 commit 42413dd

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

src/fillarrays/kroneckerarray.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatr
2222
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
2323

2424
_getindex(a::Eye, I1::Colon, I2::Colon) = a
25+
_getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
26+
_getindex(a::Eye, I1::Base.Slice, I2::Colon) = a
27+
_getindex(a::Eye, I1::Colon, I2::Base.Slice) = a
2528
_view(a::Eye, I1::Colon, I2::Colon) = a
2629
_view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
30+
_view(a::Eye, I1::Base.Slice, I2::Colon) = a
31+
_view(a::Eye, I1::Colon, I2::Base.Slice) = a
2732

2833
# Like `adapt` but preserves `Eye`.
2934
_adapt(to, a::Eye) = a

src/kroneckerarray.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,18 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {
167167
return a[I′...]
168168
end
169169

170+
# Indexing logic.
171+
function Base.to_indices(a::KroneckerArray, inds, I::Tuple{Union{CartesianPair,CartesianProduct},Vararg})
172+
I1 = to_indices(arg1(a), arg1.(inds), arg1.(I))
173+
I2 = to_indices(arg2(a), arg2.(inds), arg2.(I))
174+
return I1 I2
175+
end
176+
170177
# Allow customizing for `FillArrays.Eye`.
171178
_getindex(a::AbstractArray, I...) = a[I...]
172-
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
173-
return _getindex(arg1(a), arg1.(I)...) _getindex(arg2(a), arg2.(I)...)
174-
end
175-
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
176-
return _getindex(arg1(a), arg1.(I)...) _getindex(arg2(a), arg2.(I)...)
179+
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N}) where {N}
180+
I′ = to_indices(a, I)
181+
return _getindex(arg1(a), arg1.(I′)...) _getindex(arg2(a), arg2.(I′)...)
177182
end
178183
# Fix ambigiuity error.
179184
Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[]

test/test_blocksparsearrays.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ arrayts = (Array, JLArray)
4848
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
4949
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
5050
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
51-
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
51+
@test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
5252

5353
# Blockwise slicing, shows up in truncated block sparse matrix factorizations.
5454
I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])
@@ -169,7 +169,7 @@ end
169169
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
170170
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
171171
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
172-
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
172+
@test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
173173

174174
# Blockwise slicing, shows up in truncated block sparse matrix factorizations.
175175
I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])

0 commit comments

Comments
 (0)