Skip to content

Commit 8b02fa6

Browse files
committed
Specialize getindex for indexing VectorOfArray with boolean mask and colon
1 parent 61a4883 commit 8b02fa6

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/vector_of_array.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N},
7373
return Adapt.adapt(__parameterless_type(T),reshape(reduce(hcat,vecs),size(A.u[1])...,length(A.u)))
7474
end
7575

76+
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N},
77+
I::AbstractArray{Bool},J::Colon...) where {T, N}
78+
@assert length(J) == ndims(A.u[1])+1-ndims(I)
79+
@assert size(I) == size(A)[1:ndims(A)-length(J)]
80+
vecs = vec.(A.u)
81+
return Base.getindex(Adapt.adapt(__parameterless_type(T),reshape(reduce(hcat,vecs),size(A.u[1])...,length(A.u))), I, J...)
82+
end
83+
7684
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,::Colon) where {T, N} = [A.u[j][i] for j in 1:length(A)]
7785
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, ::Colon,i::Int) where {T, N} = A.u[i]
7886
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,II::AbstractArray{Int}) where {T, N} = [A.u[j][i] for j in II]

test/gpu/vectorofarray_gpu.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
1-
using RecursiveArrayTools, CUDA
1+
using RecursiveArrayTools, CUDA, Test
22
CUDA.allowscalar(false)
33

4+
# Test indexing with colon
45
x = zeros(5)
56
y = VectorOfArray([x,x,x])
67
y[:,:]
78

89
x = CUDA.zeros(5)
910
y = VectorOfArray([x,x,x])
1011
y[:,:]
12+
13+
# Test indexing with boolean masks and colon
14+
nx, ny, nt = 3, 4, 5
15+
x = CUDA.rand(nx, ny, nt)
16+
m = CUDA.rand(nx, ny) .> 0.5
17+
x[m, :]
18+
19+
va = VectorOfArray([slice for slice in eachslice(x, dims=3)])
20+
@test va[m, :] x[m, :]
21+
va[m, :]
22+
23+
xc = Array(x)
24+
mc = Array(m)
25+
xc[mc, :]
26+

0 commit comments

Comments
 (0)