Skip to content

Commit 77c269c

Browse files
committed
Fix tests
1 parent 1995934 commit 77c269c

File tree

4 files changed

+40
-19
lines changed

4 files changed

+40
-19
lines changed

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,17 @@ end
158158
const BlockSliceCollection = Union{
159159
Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}}
160160
}
161+
const BlockIndexRangeSlice = BlockSlice{
162+
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
163+
}
164+
const BlockIndexRangeSlices = BlockIndices{
165+
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
166+
}
167+
const BlockIndexVectorSlices = BlockIndices{
168+
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}
169+
}
161170
const SubBlockSliceCollection = Union{
162-
BlockSlice{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}},
163-
BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}},
171+
BlockIndexRangeSlice,BlockIndexRangeSlices,BlockIndexVectorSlices
164172
}
165173

166174
# TODO: This is type piracy. This is used in `reindex` when making

src/abstractblocksparsearray/views.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,13 @@ function to_blockindexrange(
214214
# work right now.
215215
return blocks(a.block)[Int(I)]
216216
end
217+
function to_blockindexrange(
218+
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange}}}, I::Block{1}
219+
)
220+
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
221+
# work right now.
222+
return blocks(a.blocks)[Int(I)]
223+
end
217224
function to_blockindexrange(
218225
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}}, I::Block{1}
219226
)
@@ -254,18 +261,10 @@ function BlockArrays.viewblock(
254261
return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...)
255262
end
256263

257-
# Block slice of the result of slicing `@view a[2:5, 2:5]`.
258-
# TODO: Move this to `BlockArraysExtensions`.
259-
const BlockIndexRangeSlice = BlockSlice{
260-
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
261-
}
262-
const BlockIndexVectorSlice = BlockSlice{
263-
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}
264-
}
265-
const BlockedSlice = Union{BlockIndexRangeSlice,BlockIndexVectorSlice}
266-
267264
function Base.view(
268-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
265+
a::SubArray{
266+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
267+
},
269268
block::Union{Block{N},BlockIndexRange{N}},
270269
) where {T,N}
271270
return viewblock(a, block)
@@ -277,13 +276,17 @@ function Base.view(
277276
return viewblock(a, block)
278277
end
279278
function Base.view(
280-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
279+
a::SubArray{
280+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
281+
},
281282
block::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
282283
) where {T,N}
283284
return viewblock(a, block...)
284285
end
285286
function BlockArrays.viewblock(
286-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
287+
a::SubArray{
288+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
289+
},
287290
block::Union{Block{N},BlockIndexRange{N}},
288291
) where {T,N}
289292
return viewblock(a, to_tuple(block)...)
@@ -294,7 +297,9 @@ blockedslice_blocks(x::BlockIndices) = x.blocks
294297

295298
# TODO: Define `@interface BlockSparseArrayInterface() viewblock`.
296299
function BlockArrays.viewblock(
297-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
300+
a::SubArray{
301+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
302+
},
298303
I::Vararg{Block{1},N},
299304
) where {T,N}
300305
# TODO: Use `reindex`, `to_indices`, etc.
@@ -308,7 +313,9 @@ function BlockArrays.viewblock(
308313
end
309314
# TODO: Define `@interface BlockSparseArrayInterface() viewblock`.
310315
function BlockArrays.viewblock(
311-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
316+
a::SubArray{
317+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
318+
},
312319
block::Vararg{BlockIndexRange{1},N},
313320
) where {T,N}
314321
return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...)

src/factorizations/svd.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
1212
alg::A
1313
end
1414

15-
function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
15+
function MatrixAlgebraKit.default_svd_algorithm(
16+
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
17+
)
1618
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
1719
error("unsupported type: $(blocktype(A))")
1820
# TODO: this is a hardcoded for now to get around this function not being defined in the

test/test_basics.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,11 @@ arrayts = (Array, JLArray)
588588
@views for b in [Block(1, 2), Block(2, 1)]
589589
a[b] = dev(randn(elt, size(a[b])))
590590
end
591-
for b in (a[2:4, 2:4], @view(a[2:4, 2:4]))
591+
I = 2:4
592+
I′, J′ = falses.(size(a))
593+
I′[I] .= true
594+
J′[I] .= true
595+
for b in (a[I, I], @view(a[I, I]), a[I′, J′], @view(a[I′, J′]))
592596
@allowscalar @test b == Array(a)[2:4, 2:4]
593597
@test size(b) == (3, 3)
594598
@test blocksize(b) == (2, 2)

0 commit comments

Comments
 (0)