@@ -124,50 +124,56 @@ Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer
124
124
125
125
"""
    getindex(A::AbstractGPUArray, I...)

Normalize the indices `I` with `to_indices` and forward them to `_getindex`.
"""
function Base.getindex(A::AbstractGPUArray, I...)
    return _getindex(A, to_indices(A, I)...)
end
126
126
127
"""
    _getindex(src::AbstractGPUArray, Is...)

Allocate a destination array shaped like the result of indexing with `Is` and
launch `getindex_kernel` to fill it from `src`. Returns the destination.
"""
function _getindex(src::AbstractGPUArray, Is...)
    dest = similar(src, Base.index_shape(Is...))

    # indexing with an empty array: nothing to gather
    if any(isempty, Is)
        return dest
    end

    idims = map(length, Is)
    AT = typeof(src).name.wrapper
    # NOTE: we are pretty liberal here supporting non-GPU indices...
    gpu_call(getindex_kernel, dest, src, idims, adapt(AT, Is)...)
    return dest
end
138
+
139
"""
    getindex_kernel(ctx::AbstractKernelContext, dest, src, idims, Is...)

Kernel body for multi-dimensional `getindex`. For the linear index `i` of
`dest` handled by this invocation, the position of `i` inside an array of size
`idims` selects one entry from each index collection in `Is`; `src` is read at
those entries and the value is stored in `dest[i]`.

The body is generated so that the per-dimension lookups are unrolled into `N`
separate statements (`I_1`, ..., `I_N`) instead of going through a closure.
"""
@generated function getindex_kernel(ctx::AbstractKernelContext, dest, src, idims,
                                    Is::Vararg{Any,N}) where {N}
    # `Vararg{Any,N}` instead of `Vararg{<:Any,N}`: wrapping `Vararg` in a
    # `UnionAll` is deprecated, and tuple covariance makes both accept the same calls.
    quote
        # NOTE(review): presumably `@linearidx` also returns early for
        # out-of-range threads -- confirm against its definition.
        i = @linearidx dest
        @inbounds begin
            is = CartesianIndices(idims)[i]
            # `d` is the unrolled dimension counter; it is distinct from the
            # runtime linear index `i` above.
            @nexprs $N d -> I_d = Is[d][is[d]]
            dest[i] = @ncall $N getindex src d -> I_d
        end
        return
    end
end
148
151
149
152
"""
    setindex!(A::AbstractGPUArray, v, I...)

Normalize the indices `I` with `to_indices` and forward them, together with
the value(s) `v`, to `_setindex!`.
"""
function Base.setindex!(A::AbstractGPUArray, v, I...)
    return _setindex!(A, v, to_indices(A, I)...)
end
150
153
151
"""
    _setindex!(dest::AbstractGPUArray, src, Is...)

Launch `setindex_kernel` to write elements of `src` into `dest` at the
positions selected by `Is`. Returns `dest`; no kernel is launched when there
are no indices or when the selection is empty.
"""
function _setindex!(dest::AbstractGPUArray, src, Is...)
    if isempty(Is)
        return dest
    end

    idims = map(length, Is)
    len = prod(idims)
    if len == 0
        return dest
    end

    AT = typeof(dest).name.wrapper
    # NOTE: we are pretty liberal here supporting non-GPU sources and indices...
    gpu_call(setindex_kernel, dest, adapt(AT, src), idims, len, adapt(AT, Is)...;
             total_threads=len)
    return dest
end
166
+
167
"""
    setindex_kernel(ctx::AbstractKernelContext, dest, src, idims, len, Is...)

Kernel body for multi-dimensional `setindex!`. For the linear index `i`
obtained from `ctx`, the position of `i` inside an array of size `idims`
selects one entry from each index collection in `Is`, and `src[i]` is stored
in `dest` at those entries. Invocations with `i > len` do nothing.

The body is generated so that the per-dimension lookups are unrolled into `N`
separate statements (`I_1`, ..., `I_N`) instead of going through a closure.
"""
@generated function setindex_kernel(ctx::AbstractKernelContext, dest, src, idims, len,
                                    Is::Vararg{Any,N}) where {N}
    # `Vararg{Any,N}` instead of `Vararg{<:Any,N}`: wrapping `Vararg` in a
    # `UnionAll` is deprecated, and tuple covariance makes both accept the same calls.
    quote
        i = linear_index(ctx)
        # surplus invocations beyond the number of selected elements
        i > len && return
        @inbounds begin
            is = CartesianIndices(idims)[i]
            # `d` is the unrolled dimension counter; `src[i]` below keeps
            # using the runtime linear index `i`.
            @nexprs $N d -> I_d = Is[d][is[d]]
            @ncall $N setindex! dest src[i] d -> I_d
        end
        return
    end
end
0 commit comments