Skip to content

Commit 2a8540e

Browse files
authored
Support WrappedGPUArray nd indexing by fusing vectorized fallback. (#512)
1 parent 4384d48 commit 2a8540e

File tree

3 files changed

+45
-6
lines changed

3 files changed

+45
-6
lines changed

src/host/base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ end
301301
struct ToGPU
302302
array::AbstractGPUArray
303303
end
304+
ToGPU(A::WrappedArray) = ToGPU(parent(A))
304305
function Adapt.adapt_storage(to::ToGPU, xs::Array)
305306
arr = similar(to.array, eltype(xs), size(xs))
306307
copyto!(arr, xs)

src/host/indexing.jl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,24 @@ end
6161

6262
## vectorized indexing
6363

64-
function vectorized_getindex(src::AbstractGPUArray, Is...)
65-
shape = Base.index_shape(Is...)
66-
dest = similar(src, shape)
64+
function vectorized_getindex!(dest::AbstractGPUArray, src::AbstractArray, Is...)
6765
any(isempty, Is) && return dest # indexing with empty array
6866
idims = map(length, Is)
6967

7068
# NOTE: we are pretty liberal here supporting non-GPU indices...
71-
Is = map(x->adapt(ToGPU(src), x), Is)
69+
Is = map(adapt(ToGPU(dest)), Is)
7270
@boundscheck checkbounds(src, Is...)
7371

7472
gpu_call(getindex_kernel, dest, src, idims, Is...)
7573
return dest
7674
end
7775

76+
function vectorized_getindex(src::AbstractGPUArray, Is...)
77+
shape = Base.index_shape(Is...)
78+
dest = similar(src, shape)
79+
return vectorized_getindex!(dest, src, Is...)
80+
end
81+
7882
@generated function getindex_kernel(ctx::AbstractKernelContext, dest, src, idims,
7983
Is::Vararg{Any,N}) where {N}
8084
quote
@@ -87,7 +91,7 @@ end
8791
end
8892
end
8993

90-
function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
94+
function vectorized_setindex!(dest::AbstractArray, src, Is...)
9195
isempty(Is) && return dest
9296
idims = length.(Is)
9397
len = prod(idims)
@@ -101,7 +105,7 @@ function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
101105
end
102106

103107
# NOTE: we are pretty liberal here supporting non-GPU indices...
104-
Is = map(x->adapt(ToGPU(dest), x), Is)
108+
Is = map(adapt(ToGPU(dest)), Is)
105109
@boundscheck checkbounds(dest, Is...)
106110

107111
gpu_call(setindex_kernel, dest, adapt(ToGPU(dest), src), idims, len, Is...;
@@ -144,6 +148,23 @@ end
144148
end)
145149
end
146150

151+
## Vectorized index overloading for `WrappedGPUArray`
152+
# We'd better not overload `getindex`/`setindex!` directly, as otherwise
153+
# the ambiguities from the default scalar fallback become a mess.
154+
# The default `getindex` for `AbstractArray` follows a `similar`-`copyto!` style.
155+
# Thus we only dispatch the `copyto!` part (`Base._unsafe_getindex!`) to our implementation.
156+
function Base._unsafe_getindex!(dest::AbstractGPUArray, src::AbstractArray, Is::Vararg{Union{Real, AbstractArray}, N}) where {N}
157+
return vectorized_getindex!(dest, src, Base.ensure_indexable(Is)...)
158+
end
159+
# Similarly for `setindex!`: its default fallback is equivalent to `copyto!`.
160+
# We only dispatch the `copyto!` part (`Base._unsafe_setindex!`) to our implementation.
161+
function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
162+
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
163+
end
164+
# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
165+
function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
166+
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
167+
end
147168

148169
# find*
149170

test/testsuite/indexing.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,23 @@ end
135135
@test compare(a->a[1:1,1:1], AT, a)
136136
@test compare(a->a[1:1,1:1,1:1], AT, a)
137137
end
138+
139+
@testset "getindex for WrapperGPUArray" begin
140+
a = rand(Float32, 5, 5)
141+
@test compare(a->a'[:, 1], AT, a)
142+
@test compare(a->Base.PermutedDimsArray(a, (2, 1))[2:-1:1, 1:2], AT, a)
143+
@test compare(a->LowerTriangular(a)[:], AT, a) broken=(string(AT) in ["MtlArray", "oneArray"])
144+
@test compare(a->Symmetric(a, :U)[a .> 0], AT, a)
145+
end
146+
147+
@testset "setindex! for WrapperGPUArray" for T in eltypes
148+
x = AT(zeros(T, (10, 10)))'
149+
y = AT(rand(T, (5, 5)))
150+
x[2:6, 2:6] = y
151+
@test Array(parent(x)[2:6, 2:6]) == Array(y)'
152+
x[2:6, 2:6] = 1:25
153+
@test Array(parent(x)[2:6, 2:6]) == reshape(1:25, 5, 5)'
154+
end
138155
end
139156

140157
@testsuite "indexing find" (AT, eltypes)->begin

0 commit comments

Comments
 (0)