Skip to content

Commit eeb23d2

Browse files
committed
Go back to generated functions.
Some cases didn't infer properly with `ntuple`, and using a generated function makes it possible to reuse the `Base.Cartesian` macros (`@nexprs`, `@ncall`).
1 parent 75f8b40 commit eeb23d2

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

src/host/indexing.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -124,50 +124,56 @@ Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer
124124

125125
Base.getindex(A::AbstractGPUArray, I...) = _getindex(A, to_indices(A, I)...)
126126

127-
function _getindex(src::AbstractGPUArray, Is::Vararg{<:Any,N}) where {N}
127+
function _getindex(src::AbstractGPUArray, Is...)
128128
shape = Base.index_shape(Is...)
129129
dest = similar(src, shape)
130130
any(isempty, Is) && return dest # indexing with empty array
131131
idims = map(length, Is)
132132

133-
function kernel(ctx::AbstractKernelContext, dest::AbstractArray, src::AbstractArray, idims, Is)
133+
AT = typeof(src).name.wrapper
134+
# NOTE: we are pretty liberal here supporting non-GPU indices...
135+
gpu_call(getindex_kernel, dest, src, idims, adapt(AT, Is)...)
136+
return dest
137+
end
138+
139+
@generated function getindex_kernel(ctx::AbstractKernelContext, dest, src, idims,
140+
Is::Vararg{<:Any,N}) where {N}
141+
quote
134142
i = @linearidx dest
135143
@inbounds begin
136144
is = CartesianIndices(idims)[i]
137-
idx = ntuple(dim -> @inbounds(Is[dim][is[dim]]), N)
138-
dest[i] = getindex(src, idx...)
145+
@nexprs $N i -> I_i = Is[i][is[i]]
146+
dest[i] = @ncall $N getindex src i -> I_i
139147
end
140148
return
141149
end
142-
143-
AT = typeof(src).name.wrapper
144-
# NOTE: we are pretty liberal here supporting non-GPU indices...
145-
gpu_call(kernel, dest, src, idims, adapt(AT, Is); name="getindex!")
146-
return dest
147150
end
148151

149152
Base.setindex!(A::AbstractGPUArray, v, I...) = _setindex!(A, v, to_indices(A, I)...)
150153

151-
function _setindex!(dest::AbstractGPUArray, src, Is::Vararg{<:Any,N}) where {N}
154+
function _setindex!(dest::AbstractGPUArray, src, Is...)
152155
isempty(Is) && return dest
153156
idims = length.(Is)
154157
len = prod(idims)
155158
len==0 && return dest
156159

157-
function kernel(ctx::AbstractKernelContext, dest, src, idims, len, Is)
160+
AT = typeof(dest).name.wrapper
161+
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
162+
gpu_call(setindex_kernel, dest, adapt(AT, src), idims, len, adapt(AT, Is)...;
163+
total_threads=len)
164+
return dest
165+
end
166+
167+
@generated function setindex_kernel(ctx::AbstractKernelContext, dest, src, idims, len,
168+
Is::Vararg{<:Any,N}) where {N}
169+
quote
158170
i = linear_index(ctx)
159171
i > len && return
160172
@inbounds begin
161173
is = CartesianIndices(idims)[i]
162-
idx = ntuple(dim -> @inbounds(Is[dim][is[dim]]), N)
163-
setindex!(dest, src[i], idx...)
174+
@nexprs $N i -> I_i = Is[i][is[i]]
175+
@ncall $N setindex! dest src[i] i -> I_i
164176
end
165177
return
166178
end
167-
168-
AT = typeof(dest).name.wrapper
169-
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
170-
gpu_call(kernel, dest, adapt(AT, src), idims, len, adapt(AT, Is);
171-
total_threads=len, name="setindex!")
172-
return dest
173179
end

0 commit comments

Comments
 (0)