Skip to content

Commit c6d1d15

Browse files
committed
Format
1 parent f7754bc commit c6d1d15

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,9 @@ function SparseArraysBase.getunstoredindex(
410410
_perm(a.array),
411411
)
412412
end
413-
function SparseArraysBase.eachstoredindex(a::SparsePermutedDimsArrayBlocks)
413+
function SparseArraysBase.eachstoredindex(
414+
::IndexCartesian, a::SparsePermutedDimsArrayBlocks
415+
)
414416
return map(I -> _getindices(I, _perm(a.array)), eachstoredindex(blocks(parent(a.array))))
415417
end
416418
## TODO: Define `storedvalues` instead.

src/blocksparsearrayinterface/getunstoredblock.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ function Base.AbstractArray{A}(a::ZeroBlocks{N}) where {N,A}
1818
end
1919

2020
@inline function Base.getindex(a::ZeroBlocks{N,A}, I::Vararg{Int,N}) where {N,A}
21+
# TODO: Use `BlockArrays.eachblockaxes`.
2122
ax = ntuple(N) do d
2223
return only(axes(a.parentaxes[d][Block(I[d])]))
2324
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
88
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
99
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1010
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
11+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

test/test_basics.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ arrayts = (Array, JLArray)
399399
a_dest = a1 * a2
400400
@allowscalar @test Array(a_dest) Array(a1) * Array(a2)
401401
@test a_dest isa BlockSparseArray{elt}
402-
@test_broken blockstoredlength(a_dest) == 1
402+
@test blockstoredlength(a_dest) == 1
403403
end
404404
@testset "Matrix multiplication" begin
405405
a1 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
@@ -430,23 +430,23 @@ arrayts = (Array, JLArray)
430430
a2[Block(1, 2)] = dev(randn(elt, size(@view(a2[Block(1, 2)]))))
431431

432432
a_dest = cat(a1, a2; dims=1)
433-
@test_broken blockstoredlength(a_dest) == 2
433+
@test blockstoredlength(a_dest) == 2
434434
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3])
435-
@test_broken issetequal(eachblockstoredindex(a_dest), [Block(2, 1), Block(3, 2)])
435+
@test issetequal(eachblockstoredindex(a_dest), [Block(2, 1), Block(3, 2)])
436436
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
437437
@test a_dest[Block(3, 2)] == a2[Block(1, 2)]
438438

439439
a_dest = cat(a1, a2; dims=2)
440-
@test_broken blockstoredlength(a_dest) == 2
440+
@test blockstoredlength(a_dest) == 2
441441
@test blocklengths.(axes(a_dest)) == ([2, 3], [2, 3, 2, 3])
442-
@test_broken issetequal(eachblockstoredindex(a_dest), [Block(2, 1), Block(1, 4)])
442+
@test issetequal(eachblockstoredindex(a_dest), [Block(2, 1), Block(1, 4)])
443443
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
444444
@test a_dest[Block(1, 4)] == a2[Block(1, 2)]
445445

446446
a_dest = cat(a1, a2; dims=(1, 2))
447-
@test_broken blockstoredlength(a_dest) == 2
447+
@test blockstoredlength(a_dest) == 2
448448
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3, 2, 3])
449-
@test_broken issetequal(eachblockstoredindex(a_dest), [Block(2, 1), Block(3, 4)])
449+
@test issetequal(eachblockstoredindex(a_dest), [Block(2, 1), Block(3, 4)])
450450
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
451451
@test a_dest[Block(3, 4)] == a2[Block(1, 2)]
452452
end

0 commit comments

Comments
 (0)