@@ -122,56 +122,52 @@ Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer
122
122
123
123
# generalized multidimensional indexing
124
124
125
- @generated function index_kernel (ctx:: AbstractKernelContext , dest:: AbstractArray , src:: AbstractArray , idims, Is)
126
- N = length (Is. parameters)
127
- quote
128
- i = @linearidx dest
129
- is = CartesianIndices (idims)[i]
130
- @nexprs $ N i -> @inbounds I_i = Is[i][is[i]]
131
- @inbounds dest[i] = @ncall $ N getindex src i -> I_i
132
- return
133
- end
134
- end
135
-
136
- function Base. getindex (A:: AbstractGPUArray , I... )
137
- _getindex (A, to_indices (A, I)... )
138
- end
125
+ Base. getindex (A:: AbstractGPUArray , I... ) = _getindex (A, to_indices (A, I)... )
139
126
140
- function _getindex (src:: AbstractGPUArray , Is... )
127
+ function _getindex (src:: AbstractGPUArray , Is:: Vararg{<:Any,N} ) where {N}
141
128
shape = Base. index_shape (Is... )
142
129
dest = similar (src, shape)
143
130
any (isempty, Is) && return dest # indexing with empty array
144
131
idims = map (length, Is)
145
- AT = typeof (src). name. wrapper
146
- # NOTE: we are pretty liberal here supporting non-GPU indices...
147
- gpu_call (index_kernel, dest, src, idims, adapt (AT, Is))
148
- return dest
149
- end
150
132
151
- @generated function setindex_kernel! (ctx:: AbstractKernelContext , dest:: AbstractArray , src, idims, Is, len)
152
- N = length (Is. parameters)
153
- idx = ntuple (i-> :(Is[$ i][is[$ i]]), N)
154
- quote
155
- i = linear_index (ctx)
156
- i > len && return
157
- is = CartesianIndices (idims)[i]
158
- @inbounds setindex! (dest, src[is], $ (idx... ))
133
+ function kernel (ctx:: AbstractKernelContext , dest:: AbstractArray , src:: AbstractArray , idims, Is)
134
+ i = @linearidx dest
135
+ @inbounds begin
136
+ is = CartesianIndices (idims)[i]
137
+ idx = ntuple (dim -> @inbounds (Is[dim][is[dim]]), N)
138
+ dest[i] = getindex (src, idx... )
139
+ end
159
140
return
160
141
end
161
- end
162
142
163
- function Base. setindex! (A:: AbstractGPUArray , v, I... )
164
- _setindex! (A, v, to_indices (A, I)... )
143
+ AT = typeof (src). name. wrapper
144
+ # NOTE: we are pretty liberal here supporting non-GPU indices...
145
+ gpu_call (kernel, dest, src, idims, adapt (AT, Is); name= " getindex!" )
146
+ return dest
165
147
end
166
148
167
- function _setindex! (dest:: AbstractGPUArray , src, Is... )
149
+ Base. setindex! (A:: AbstractGPUArray , v, I... ) = _setindex! (A, v, to_indices (A, I)... )
150
+
151
+ function _setindex! (dest:: AbstractGPUArray , src, Is:: Vararg{<:Any,N} ) where {N}
168
152
isempty (Is) && return dest
169
153
idims = length .(Is)
170
154
len = prod (idims)
171
155
len== 0 && return dest
156
+
157
+ function kernel (ctx:: AbstractKernelContext , dest, src, idims, len, Is)
158
+ i = linear_index (ctx)
159
+ i > len && return
160
+ @inbounds begin
161
+ is = CartesianIndices (idims)[i]
162
+ idx = ntuple (dim -> @inbounds (Is[dim][is[dim]]), N)
163
+ setindex! (dest, src[i], idx... )
164
+ end
165
+ return
166
+ end
167
+
172
168
AT = typeof (dest). name. wrapper
173
169
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
174
- gpu_call (setindex_kernel! , dest, adapt (AT, src), idims, adapt (AT, Is), len ;
175
- total_threads= len)
170
+ gpu_call (kernel , dest, adapt (AT, src), idims, len, adapt (AT, Is);
171
+ total_threads= len, name = " setindex! " )
176
172
return dest
177
173
end
0 commit comments