61
61
62
62
# # vectorized indexing
63
63
64
- function vectorized_getindex (src:: AbstractGPUArray , Is... )
65
- shape = Base. index_shape (Is... )
66
- dest = similar (src, shape)
64
+ function vectorized_getindex! (dest:: AbstractGPUArray , src:: AbstractArray , Is... )
67
65
any (isempty, Is) && return dest # indexing with empty array
68
66
idims = map (length, Is)
69
67
70
68
# NOTE: we are pretty liberal here supporting non-GPU indices...
71
- Is = map (x -> adapt (ToGPU (src), x ), Is)
69
+ Is = map (adapt (ToGPU (dest) ), Is)
72
70
@boundscheck checkbounds (src, Is... )
73
71
74
72
gpu_call (getindex_kernel, dest, src, idims, Is... )
75
73
return dest
76
74
end
77
75
76
+ function vectorized_getindex (src:: AbstractGPUArray , Is... )
77
+ shape = Base. index_shape (Is... )
78
+ dest = similar (src, shape)
79
+ return vectorized_getindex! (dest, src, Is... )
80
+ end
81
+
78
82
@generated function getindex_kernel (ctx:: AbstractKernelContext , dest, src, idims,
79
83
Is:: Vararg{Any,N} ) where {N}
80
84
quote
87
91
end
88
92
end
89
93
90
- function vectorized_setindex! (dest:: AbstractGPUArray , src, Is... )
94
+ function vectorized_setindex! (dest:: AbstractArray , src, Is... )
91
95
isempty (Is) && return dest
92
96
idims = length .(Is)
93
97
len = prod (idims)
@@ -101,7 +105,7 @@ function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
101
105
end
102
106
103
107
# NOTE: we are pretty liberal here supporting non-GPU indices...
104
- Is = map (x -> adapt (ToGPU (dest), x ), Is)
108
+ Is = map (adapt (ToGPU (dest)), Is)
105
109
@boundscheck checkbounds (dest, Is... )
106
110
107
111
gpu_call (setindex_kernel, dest, adapt (ToGPU (dest), src), idims, len, Is... ;
144
148
end )
145
149
end
146
150
151
+ # # Vectorized index overloading for `WrappedGPUArray`
152
+ # We'd better not to overload `getindex`/`setindex!` directly as otherwise
153
+ # the ambiguities from the default scalar fallback become a mess.
154
+ # The default `getindex` for `AbstractArray` follows a `similar`-`copyto!` style.
155
+ # Thus we only dispatch the `copyto!` part (`Base._unsafe_getindex!`) to our implement.
156
+ function Base. _unsafe_getindex! (dest:: AbstractGPUArray , src:: AbstractArray , Is:: Vararg{Union{Real, AbstractArray}, N} ) where {N}
157
+ return vectorized_getindex! (dest, src, Base. ensure_indexable (Is)... )
158
+ end
159
+ # Similar for `setindex!`, its default fallback is equivalent to `copyto!`.
160
+ # We only dispatch the `copyto!` part (`Base._unsafe_setindex!`) to our implement.
161
+ function Base. _unsafe_setindex! (:: IndexStyle , A:: WrappedGPUArray , x, Is:: Vararg{Union{Real,AbstractArray}, N} ) where N
162
+ return vectorized_setindex! (A, x, Base. ensure_indexable (Is)... )
163
+ end
164
+ # And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
165
+ function Base. _unsafe_setindex! (:: IndexStyle , A:: Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray} , x, Is:: Vararg{Union{Real,AbstractArray}, N} ) where N
166
+ return vectorized_setindex! (A, x, Base. ensure_indexable (Is)... )
167
+ end
147
168
148
169
# find*
149
170
0 commit comments