Skip to content

Commit 75f8b40

Browse files
committed
Simplify indexing kernels and fix bug with misshape inputs.
1 parent fcbd2eb commit 75f8b40

File tree

3 files changed

+35
-34
lines changed

3 files changed

+35
-34
lines changed

src/host/abstractarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ function Serialization.serialize(s::AbstractSerializer, t::T) where T <: Abstrac
3333
serialize_type(s, T)
3434
serialize(s, Array(t))
3535
end
36+
3637
function Serialization.deserialize(s::AbstractSerializer, ::Type{T}) where T <: AbstractGPUArray
3738
A = deserialize(s)
3839
T(A)

src/host/indexing.jl

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -122,56 +122,52 @@ Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer
122122

123123
# generalized multidimensional indexing
124124

125-
@generated function index_kernel(ctx::AbstractKernelContext, dest::AbstractArray, src::AbstractArray, idims, Is)
126-
N = length(Is.parameters)
127-
quote
128-
i = @linearidx dest
129-
is = CartesianIndices(idims)[i]
130-
@nexprs $N i -> @inbounds I_i = Is[i][is[i]]
131-
@inbounds dest[i] = @ncall $N getindex src i -> I_i
132-
return
133-
end
134-
end
135-
136-
function Base.getindex(A::AbstractGPUArray, I...)
137-
_getindex(A, to_indices(A, I)...)
138-
end
125+
Base.getindex(A::AbstractGPUArray, I...) = _getindex(A, to_indices(A, I)...)
139126

140-
function _getindex(src::AbstractGPUArray, Is...)
127+
function _getindex(src::AbstractGPUArray, Is::Vararg{<:Any,N}) where {N}
141128
shape = Base.index_shape(Is...)
142129
dest = similar(src, shape)
143130
any(isempty, Is) && return dest # indexing with empty array
144131
idims = map(length, Is)
145-
AT = typeof(src).name.wrapper
146-
# NOTE: we are pretty liberal here supporting non-GPU indices...
147-
gpu_call(index_kernel, dest, src, idims, adapt(AT, Is))
148-
return dest
149-
end
150132

151-
@generated function setindex_kernel!(ctx::AbstractKernelContext, dest::AbstractArray, src, idims, Is, len)
152-
N = length(Is.parameters)
153-
idx = ntuple(i-> :(Is[$i][is[$i]]), N)
154-
quote
155-
i = linear_index(ctx)
156-
i > len && return
157-
is = CartesianIndices(idims)[i]
158-
@inbounds setindex!(dest, src[is], $(idx...))
133+
function kernel(ctx::AbstractKernelContext, dest::AbstractArray, src::AbstractArray, idims, Is)
134+
i = @linearidx dest
135+
@inbounds begin
136+
is = CartesianIndices(idims)[i]
137+
idx = ntuple(dim -> @inbounds(Is[dim][is[dim]]), N)
138+
dest[i] = getindex(src, idx...)
139+
end
159140
return
160141
end
161-
end
162142

163-
function Base.setindex!(A::AbstractGPUArray, v, I...)
164-
_setindex!(A, v, to_indices(A, I)...)
143+
AT = typeof(src).name.wrapper
144+
# NOTE: we are pretty liberal here supporting non-GPU indices...
145+
gpu_call(kernel, dest, src, idims, adapt(AT, Is); name="getindex!")
146+
return dest
165147
end
166148

167-
function _setindex!(dest::AbstractGPUArray, src, Is...)
149+
Base.setindex!(A::AbstractGPUArray, v, I...) = _setindex!(A, v, to_indices(A, I)...)
150+
151+
function _setindex!(dest::AbstractGPUArray, src, Is::Vararg{<:Any,N}) where {N}
168152
isempty(Is) && return dest
169153
idims = length.(Is)
170154
len = prod(idims)
171155
len==0 && return dest
156+
157+
function kernel(ctx::AbstractKernelContext, dest, src, idims, len, Is)
158+
i = linear_index(ctx)
159+
i > len && return
160+
@inbounds begin
161+
is = CartesianIndices(idims)[i]
162+
idx = ntuple(dim -> @inbounds(Is[dim][is[dim]]), N)
163+
setindex!(dest, src[i], idx...)
164+
end
165+
return
166+
end
167+
172168
AT = typeof(dest).name.wrapper
173169
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
174-
gpu_call(setindex_kernel!, dest, adapt(AT, src), idims, adapt(AT, Is), len;
175-
total_threads=len)
170+
gpu_call(kernel, dest, adapt(AT, src), idims, len, adapt(AT, Is);
171+
total_threads=len, name="setindex!")
176172
return dest
177173
end

test/testsuite/indexing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,8 @@ end
139139
@test compare(a->a[i',:], AT, a)
140140
@test compare(a->a[view(i,1,:),:], AT, a)
141141
end
142+
143+
@testset "JuliaGPU/CUDA.jl#461: sliced setindex" begin
144+
@test compare((X,Y)->(X[1,:] = Y), AT, zeros(2,2), ones(2))
145+
end
142146
end

0 commit comments

Comments
 (0)