Skip to content

Commit 0d11eec

Browse files
committed
Noncontiguous slicing on GPU
1 parent 0b786d7 commit 0d11eec

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/abstractblocksparsearray/unblockedsubarray.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@ function TypeParameterAccessors.similartype(arraytype::Type{<:UnblockedSubArray}
2727
return similartype(blocktype(parenttype(arraytype)), elt)
2828
end
2929

30+
function Base.similar(
31+
a::UnblockedSubArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
32+
)
33+
return similar(similartype(blocktype(parenttype(a)), elt), axes)
34+
end
35+
function Base.similar(a::UnblockedSubArray, elt::Type, size::Tuple{Int,Vararg{Int}})
36+
return similar(a, elt, Base.OneTo.(size))
37+
end
38+
39+
function ArrayLayouts.sub_materialize(a::UnblockedSubArray)
40+
a_cpu = adapt(Array, a)
41+
a_cpu′ = similar(a_cpu)
42+
a_cpu′ .= a_cpu
43+
if typeof(a) === typeof(a_cpu)
44+
return a_cpu′
45+
end
46+
a′ = similar(a)
47+
a′ .= a_cpu′
48+
return a′
49+
end
50+
3051
function Base.map!(
3152
f, a_dest::AbstractArray, a_src1::UnblockedSubArray, a_src_rest::UnblockedSubArray...
3253
)

test/test_basics.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ arrayts = (Array, JLArray)
5656
a[Block(1, 1)] = dev(randn(elt, 2, 2))
5757
a[Block(2, 2)] = dev(randn(elt, 3, 3))
5858
@test_broken a[:, [2, 4]]
59-
@test_broken a[[3, 5], [2, 4]]
6059

6160
# TODO: Fix this and turn it into a proper test.
6261
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
@@ -713,6 +712,13 @@ arrayts = (Array, JLArray)
713712
@test a[Block(2, 2)[1:2, 2:3]] == b
714713
@test blockstoredlength(a) == 1
715714

715+
# Non-contiguous slicing.
716+
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
717+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
718+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
719+
I = ([3, 5], [2, 4])
720+
@test a[I...] == Array(a)[I...]
721+
716722
a = BlockSparseArray{elt}(undef, [2, 3], [2, 3])
717723
@views for b in [Block(1, 1), Block(2, 2)]
718724
a[b] = randn(elt, size(a[b]))

0 commit comments

Comments
 (0)