@@ -6,6 +6,14 @@ using GPUArrays
6
6
7
7
export JLArray
8
8
9
+
10
+ #
11
+ # Host array
12
+ #
13
+
14
+ # the definition of a host array type, implementing different Base interfaces
15
+ # to make it function properly and behave like the Base Array type.
16
+
9
17
struct JLArray{T, N} <: AbstractGPUArray{T, N}
10
18
data:: Array{T, N}
11
19
dims:: Dims{N}
@@ -15,12 +23,7 @@ struct JLArray{T, N} <: AbstractGPUArray{T, N}
15
23
end
16
24
end
17
25
18
-
19
- #
20
- # AbstractArray interface
21
- #
22
-
23
- # # typical constructors
26
+ # # constructors
24
27
25
28
# type and dimensionality specified, accepting dims as tuples of Ints
26
29
JLArray {T,N} (:: UndefInitializer , dims:: Dims{N} ) where {T,N} =
139
142
# AbstractGPUArray interface
140
143
#
141
144
145
+ # implementation of GPUArrays-specific interfaces
146
+
142
147
GPUArrays. unsafe_reinterpret (:: Type{T} , A:: JLArray , size:: Tuple ) where T =
143
148
reshape (reinterpret (T, A. data), size)
144
149
@@ -177,7 +182,7 @@ function JLState(state::JLState{N}, threadidx::NTuple{N}) where N
177
182
)
178
183
end
179
184
180
- to_device (state, x:: JLArray ) = x. data
185
+ to_device (state, x:: JLArray{T,N} ) where {T,N} = JLDeviceArray {T,N} ( x. data, x . dims)
181
186
to_device (state, x:: Tuple ) = to_device .(Ref (state), x)
182
187
to_device (state, x:: Base.RefValue{<: JLArray} ) = Base. RefValue (to_device (state, x[]))
183
188
to_device (state, x) = x
@@ -205,31 +210,6 @@ function GPUArrays._gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tup
205
210
end
206
211
207
212
208
- # # gpu intrinsics
209
-
210
- @inline function GPUArrays. synchronize_threads (:: JLState )
211
- # All threads are getting started asynchronously, so a yield will yield to the next
212
- # execution of the same function, which should call yield at the exact same point in the
213
- # program, leading to a chain of yields effectively syncing the tasks (threads).
214
- yield ()
215
- return
216
- end
217
-
218
- function GPUArrays. LocalMemory (state:: JLState , :: Type{T} , :: Val{N} , :: Val{C} ) where {T, N, C}
219
- state. localmem_counter += 1
220
- lmems = state. localmems[blockidx_x (state)]
221
-
222
- # first invocation in block
223
- if length (lmems) < state. localmem_counter
224
- lmem = fill (zero (T), N)
225
- push! (lmems, lmem)
226
- return lmem
227
- else
228
- return lmems[state. localmem_counter]
229
- end
230
- end
231
-
232
-
233
213
# # device properties
234
214
235
215
struct JLDevice end
@@ -249,24 +229,65 @@ GPUArrays.blasbuffer(A::JLArray) = A.data
249
229
250
230
251
231
#
252
- # AbstractDeviceArray interface
232
+ # Device array
253
233
#
254
234
255
- function GPUArrays. AbstractDeviceArray (ptr:: Array , shape:: NTuple{N, Integer} ) where N
256
- reshape (ptr, shape)
235
+ # definition of a minimal device array type that supports the subset of operations
236
+ # that are used in GPUArrays kernels
237
+
238
+ struct JLDeviceArray{T, N} <: AbstractDeviceArray{T, N}
239
+ data:: Array{T, N}
240
+ dims:: Dims{N}
241
+
242
+ function JLDeviceArray {T,N} (data:: Array{T, N} , dims:: Dims{N} ) where {T,N}
243
+ new (data, dims)
244
+ end
257
245
end
258
- function GPUArrays. AbstractDeviceArray (ptr:: Array , shape:: Vararg{Integer, N} ) where N
259
- reshape (ptr, shape)
246
+
247
+ function GPUArrays. LocalMemory (state:: JLState , :: Type{T} , :: Val{dims} , :: Val{id} ) where {T, dims, id}
248
+ state. localmem_counter += 1
249
+ lmems = state. localmems[blockidx_x (state)]
250
+
251
+ # first invocation in block
252
+ data = if length (lmems) < state. localmem_counter
253
+ lmem = fill (zero (T), dims)
254
+ push! (lmems, lmem)
255
+ lmem
256
+ else
257
+ lmems[state. localmem_counter]
258
+ end
259
+
260
+ N = length (dims)
261
+ JLDeviceArray {T,N} (data, tuple (dims... ))
260
262
end
261
263
262
264
265
+ # # array interface
266
+
267
+ Base. size (x:: JLDeviceArray ) = x. dims
268
+
269
+
263
270
# # indexing
264
271
272
+ @inline Base. getindex (A:: JLDeviceArray , index:: Integer ) = getindex (A. data, index)
273
+ @inline Base. setindex! (A:: JLDeviceArray , x, index:: Integer ) = setindex! (A. data, x, index)
274
+
265
275
for (i, sym) in enumerate ((:x , :y , :z ))
266
276
for f in (:blockidx , :blockdim , :threadidx , :griddim )
267
277
fname = Symbol (string (f, ' _' , sym))
268
278
@eval GPUArrays.$ fname (state:: JLState ) = Int (state.$ f[$ i])
269
279
end
270
280
end
271
281
282
+
283
+ # # synchronization
284
+
285
+ @inline function GPUArrays. synchronize_threads (:: JLState )
286
+ # All threads are getting started asynchronously, so a yield will yield to the next
287
+ # execution of the same function, which should call yield at the exact same point in the
288
+ # program, leading to a chain of yields effectively syncing the tasks (threads).
289
+ yield ()
290
+ return
291
+ end
292
+
272
293
end
0 commit comments