@@ -122,56 +122,55 @@ Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer
122
122
123
123
# generalized multidimensional indexing
124
124
125
- @generated function index_kernel (ctx:: AbstractKernelContext , dest:: AbstractArray , src:: AbstractArray , idims, Is)
126
- N = length (Is. parameters)
127
- quote
128
- i = @linearidx dest
129
- is = CartesianIndices (idims)[i]
130
- @nexprs $ N i -> @inbounds I_i = Is[i][is[i]]
131
- @inbounds dest[i] = @ncall $ N getindex src i -> I_i
132
- return
133
- end
134
- end
135
-
136
- function Base. getindex (A:: AbstractGPUArray , I... )
137
- _getindex (A, to_indices (A, I)... )
138
- end
125
+ Base. getindex (A:: AbstractGPUArray , I... ) = _getindex (A, to_indices (A, I)... )
139
126
140
127
function _getindex (src:: AbstractGPUArray , Is... )
141
128
shape = Base. index_shape (Is... )
142
129
dest = similar (src, shape)
143
130
any (isempty, Is) && return dest # indexing with empty array
144
131
idims = map (length, Is)
132
+
145
133
AT = typeof (src). name. wrapper
146
134
# NOTE: we are pretty liberal here supporting non-GPU indices...
147
- gpu_call (index_kernel , dest, src, idims, adapt (AT, Is))
135
+ gpu_call (getindex_kernel , dest, src, idims, adapt (AT, Is)... )
148
136
return dest
149
137
end
150
138
151
- @generated function setindex_kernel! (ctx:: AbstractKernelContext , dest:: AbstractArray , src, idims, Is, len)
152
- N = length (Is. parameters)
153
- idx = ntuple (i-> :(Is[$ i][is[$ i]]), N)
139
+ @generated function getindex_kernel (ctx:: AbstractKernelContext , dest, src, idims,
140
+ Is:: Vararg{<:Any,N} ) where {N}
154
141
quote
155
- i = linear_index (ctx)
156
- i > len && return
157
- is = CartesianIndices (idims)[i]
158
- @inbounds setindex! (dest, src[is], $ (idx... ))
142
+ i = @linearidx dest
143
+ is = @inbounds CartesianIndices (idims)[i]
144
+ @nexprs $ N i -> I_i = @inbounds (Is[i][is[i]])
145
+ val = @ncall $ N getindex src i -> I_i
146
+ @inbounds dest[i] = val
159
147
return
160
148
end
161
149
end
162
150
163
- function Base. setindex! (A:: AbstractGPUArray , v, I... )
164
- _setindex! (A, v, to_indices (A, I)... )
165
- end
151
+ Base. setindex! (A:: AbstractGPUArray , v, I... ) = _setindex! (A, v, to_indices (A, I)... )
166
152
167
153
function _setindex! (dest:: AbstractGPUArray , src, Is... )
168
154
isempty (Is) && return dest
169
155
idims = length .(Is)
170
156
len = prod (idims)
171
157
len== 0 && return dest
158
+
172
159
AT = typeof (dest). name. wrapper
173
160
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
174
- gpu_call (setindex_kernel! , dest, adapt (AT, src), idims, adapt (AT, Is), len ;
161
+ gpu_call (setindex_kernel, dest, adapt (AT, src), idims, len, adapt (AT, Is)... ;
175
162
total_threads= len)
176
163
return dest
177
164
end
165
+
166
+ @generated function setindex_kernel (ctx:: AbstractKernelContext , dest, src, idims, len,
167
+ Is:: Vararg{<:Any,N} ) where {N}
168
+ quote
169
+ i = linear_index (ctx)
170
+ i > len && return
171
+ is = @inbounds CartesianIndices (idims)[i]
172
+ @nexprs $ N i -> I_i = @inbounds (Is[i][is[i]])
173
+ @ncall $ N setindex! dest src[i] i -> I_i
174
+ return
175
+ end
176
+ end
0 commit comments