Skip to content

Commit 1cf96ab

Browse files
authored
Merge pull request #326 from JuliaGPU/tb/fix_indexing
Simplify indexing kernels and fix bug with misshaped inputs.
2 parents fcbd2eb + 07e88a6 commit 1cf96ab

File tree

3 files changed

+30
-26
lines changed

3 files changed

+30
-26
lines changed

src/host/abstractarray.jl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ function Serialization.serialize(s::AbstractSerializer, t::T) where T <: Abstrac
3333
serialize_type(s, T)
3434
serialize(s, Array(t))
3535
end
36+
3637
function Serialization.deserialize(s::AbstractSerializer, ::Type{T}) where T <: AbstractGPUArray
3738
A = deserialize(s)
3839
T(A)

src/host/indexing.jl

Lines changed: 25 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -122,56 +122,55 @@ Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer
122122

123123
# generalized multidimensional indexing
124124

125-
@generated function index_kernel(ctx::AbstractKernelContext, dest::AbstractArray, src::AbstractArray, idims, Is)
126-
N = length(Is.parameters)
127-
quote
128-
i = @linearidx dest
129-
is = CartesianIndices(idims)[i]
130-
@nexprs $N i -> @inbounds I_i = Is[i][is[i]]
131-
@inbounds dest[i] = @ncall $N getindex src i -> I_i
132-
return
133-
end
134-
end
135-
136-
function Base.getindex(A::AbstractGPUArray, I...)
137-
_getindex(A, to_indices(A, I)...)
138-
end
125+
Base.getindex(A::AbstractGPUArray, I...) = _getindex(A, to_indices(A, I)...)
139126

140127
function _getindex(src::AbstractGPUArray, Is...)
141128
shape = Base.index_shape(Is...)
142129
dest = similar(src, shape)
143130
any(isempty, Is) && return dest # indexing with empty array
144131
idims = map(length, Is)
132+
145133
AT = typeof(src).name.wrapper
146134
# NOTE: we are pretty liberal here supporting non-GPU indices...
147-
gpu_call(index_kernel, dest, src, idims, adapt(AT, Is))
135+
gpu_call(getindex_kernel, dest, src, idims, adapt(AT, Is)...)
148136
return dest
149137
end
150138

151-
@generated function setindex_kernel!(ctx::AbstractKernelContext, dest::AbstractArray, src, idims, Is, len)
152-
N = length(Is.parameters)
153-
idx = ntuple(i-> :(Is[$i][is[$i]]), N)
139+
@generated function getindex_kernel(ctx::AbstractKernelContext, dest, src, idims,
140+
Is::Vararg{<:Any,N}) where {N}
154141
quote
155-
i = linear_index(ctx)
156-
i > len && return
157-
is = CartesianIndices(idims)[i]
158-
@inbounds setindex!(dest, src[is], $(idx...))
142+
i = @linearidx dest
143+
is = @inbounds CartesianIndices(idims)[i]
144+
@nexprs $N i -> I_i = @inbounds(Is[i][is[i]])
145+
val = @ncall $N getindex src i -> I_i
146+
@inbounds dest[i] = val
159147
return
160148
end
161149
end
162150

163-
function Base.setindex!(A::AbstractGPUArray, v, I...)
164-
_setindex!(A, v, to_indices(A, I)...)
165-
end
151+
Base.setindex!(A::AbstractGPUArray, v, I...) = _setindex!(A, v, to_indices(A, I)...)
166152

167153
function _setindex!(dest::AbstractGPUArray, src, Is...)
168154
isempty(Is) && return dest
169155
idims = length.(Is)
170156
len = prod(idims)
171157
len==0 && return dest
158+
172159
AT = typeof(dest).name.wrapper
173160
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
174-
gpu_call(setindex_kernel!, dest, adapt(AT, src), idims, adapt(AT, Is), len;
161+
gpu_call(setindex_kernel, dest, adapt(AT, src), idims, len, adapt(AT, Is)...;
175162
total_threads=len)
176163
return dest
177164
end
165+
166+
@generated function setindex_kernel(ctx::AbstractKernelContext, dest, src, idims, len,
167+
Is::Vararg{<:Any,N}) where {N}
168+
quote
169+
i = linear_index(ctx)
170+
i > len && return
171+
is = @inbounds CartesianIndices(idims)[i]
172+
@nexprs $N i -> I_i = @inbounds(Is[i][is[i]])
173+
@ncall $N setindex! dest src[i] i -> I_i
174+
return
175+
end
176+
end

test/testsuite/indexing.jl

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -139,4 +139,8 @@ end
139139
@test compare(a->a[i',:], AT, a)
140140
@test compare(a->a[view(i,1,:),:], AT, a)
141141
end
142+
143+
@testset "JuliaGPU/CUDA.jl#461: sliced setindex" begin
144+
@test compare((X,Y)->(X[1,:] = Y), AT, zeros(2,2), ones(2))
145+
end
142146
end

0 commit comments

Comments
 (0)