Skip to content

Commit cce9743

Browse files
committed
Add test for permutedims bug
1 parent a2cf2f5 commit cce9743

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,22 @@ function Base.similar(
273273
)
274274
return blocksparse_similar(a, elt, axes)
275275
end
276+
277+
# TODO: Implement this in a more generic way using a smarter `copyto!`,
278+
# which is ultimately what `Array{T,N}(::AbstractArray{<:Any,N})` calls.
279+
# These are defined for now to avoid scalar indexing issues when there
280+
# are blocks on GPU.
281+
function Base.Array{T,N}(a::BlockSparseArrayLike{<:Any,N}) where {T,N}
282+
# First make it dense, then move to CPU.
283+
# Directly copying to CPU causes some issues with
284+
# scalar indexing on GPU which we have to investigate.
285+
a_dest = similartype(blocktype(a), T)(undef, size(a))
286+
a_dest .= a
287+
return Array{T,N}(a_dest)
288+
end
289+
function Base.Array{T}(a::BlockSparseArrayLike) where {T}
290+
return Array{T,ndims(a)}(a)
291+
end
292+
function Base.Array(a::BlockSparseArrayLike)
293+
return Array{eltype(a)}(a)
294+
end

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,15 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
297297
@test block_nstored(b) == 2
298298
@test nstored(b) == 2 * 4 + 3 * 3
299299

300+
a = dev(BlockSparseArray{elt}([1, 1, 1], [1, 2, 3], [2, 2, 1], [1, 2, 1]))
301+
a[Block(3, 2, 2, 3)] = dev(randn(1, 2, 2, 1))
302+
perm = (2, 3, 4, 1)
303+
for b in (PermutedDimsArray(a, perm), permutedims(a, perm))
304+
@test Array(b) == permutedims(Array(a), perm)
305+
@test issetequal(block_stored_indices(b), [Block(2, 2, 3, 3)])
306+
@test @allowscalar b[Block(2, 2, 3, 3)] == permutedims(a[Block(3, 2, 2, 3)], perm)
307+
end
308+
300309
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
301310
@views for b in [Block(1, 2), Block(2, 1)]
302311
a[b] = randn(elt, size(a[b]))

0 commit comments

Comments
 (0)