- # reference implementation of the GPUArrays interfaces
+ # reference implementation of a CPU-based array type
+
+ module JLArrays
+
+ using GPUArrays

export JLArray

@@ -12,7 +16,11 @@ struct JLArray{T, N} <: AbstractGPUArray{T, N}
end


- # # construction
+ #
+ # AbstractArray interface
+ #
+
+ # # typical constructors

# type and dimensionality specified, accepting dims as tuples of Ints
JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} =
@@ -29,7 +37,6 @@ JLArray{T}(::UndefInitializer, dims::Integer...) where {T} =
# empty vector constructor
JLArray{T,1}() where {T} = JLArray{T,1}(undef, 0)

-
Base.similar(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(undef, size(a))
Base.similar(a::JLArray{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
Base.similar(a::JLArray, ::Type{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
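
A quick usage sketch of the constructors and `similar` methods above (illustrative only; it assumes the allocating constructor bodies elided from this diff behave like Base's `undef` Array constructors):

    a = JLArray{Float32,2}(undef, (2, 3))   # eltype and dimensionality, dims as a tuple
    b = JLArray{Float32}(undef, 2, 3)       # dims as separate integers
    c = similar(a)                          # same eltype and size as a
    d = similar(a, Int, (4, 4))             # new eltype and new size
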
@@ -64,6 +71,8 @@ Base.convert(::Type{T}, x::T) where T <: JLArray = x

# # broadcast

+ using Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
+
BroadcastStyle(::Type{<:JLArray}) = ArrayStyle{JLArray}()

function Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}) where T
@@ -72,29 +81,8 @@

Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}, dims...) where {T} = JLArray{T}(undef, dims...)
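
With the `ArrayStyle` and `similar` definitions above, ordinary dot broadcasting materializes into new `JLArray`s. A minimal sketch, assuming the `JLArray(::Array)` constructor used elsewhere in this file and the generic broadcast kernels that GPUArrays provides:

    a = JLArray(rand(Float32, 4))
    b = a .+ 1f0     # result allocated via similar(::Broadcasted{ArrayStyle{JLArray}}, Float32)
    c = a .* b       # elementwise product, again a JLArray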

- # # gpuarray interface
-
- struct JLBackend <: GPUBackend end
- backend(::Type{<:JLArray}) = JLBackend()
-
- """
- Thread group local memory
- """
- struct LocalMem{N, T}
-     x::NTuple{N, Vector{T}}
- end

- to_device(state, x::JLArray) = x.data
- to_device(state, x::Tuple) = to_device.(Ref(state), x)
- to_device(state, x::Base.RefValue{<:JLArray}) = Base.RefValue(to_device(state, x[]))
- to_device(state, x) = x
-
- to_blocks(state, x) = x
- # unpacks local memory for each block
- to_blocks(state, x::LocalMem) = x.x[blockidx_x(state)]
-
- unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
-     reshape(reinterpret(T, A.data), size)
+ # # memory operations

function Base.copyto!(dest::Array{T}, d_offset::Integer,
                      source::JLArray{T}, s_offset::Integer,
@@ -103,6 +91,7 @@ function Base.copyto!(dest::Array{T}, d_offset::Integer,
    @boundscheck checkbounds(source, s_offset+amount-1)
    copyto!(dest, d_offset, source.data, s_offset, amount)
end
+
function Base.copyto!(dest::JLArray{T}, d_offset::Integer,
                      source::Array{T}, s_offset::Integer,
                      amount::Integer) where T
@@ -111,6 +100,7 @@ function Base.copyto!(dest::JLArray{T}, d_offset::Integer,
    copyto!(dest.data, d_offset, source, s_offset, amount)
    dest
end
+
function Base.copyto!(dest::JLArray{T}, d_offset::Integer,
                      source::JLArray{T}, s_offset::Integer,
                      amount::Integer) where T
@@ -120,6 +110,45 @@ function Base.copyto!(dest::JLArray{T}, d_offset::Integer,
    dest
end
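
A short sketch of the offset copies defined above, moving data between a plain `Array` and a `JLArray` (illustrative; the element types must match, as the signatures require):

    src = JLArray(collect(1.0:10.0))
    dst = zeros(10)
    copyto!(dst, 1, src, 1, 10)   # JLArray to Array, backed by src.data
    copyto!(src, 1, dst, 1, 10)   # Array to JLArray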

+ # # fft
+
+ using AbstractFFTs
+
+ # defining our own plan type is the easiest way to pass around plans from the FFTW interface
+ # without ambiguities
+
+ struct FFTPlan{T}
+     p::T
+ end
+
+ AbstractFFTs.plan_fft(A::JLArray; kw_args...) = FFTPlan(plan_fft(A.data; kw_args...))
+ AbstractFFTs.plan_fft!(A::JLArray; kw_args...) = FFTPlan(plan_fft!(A.data; kw_args...))
+ AbstractFFTs.plan_bfft!(A::JLArray; kw_args...) = FFTPlan(plan_bfft!(A.data; kw_args...))
+ AbstractFFTs.plan_bfft(A::JLArray; kw_args...) = FFTPlan(plan_bfft(A.data; kw_args...))
+ AbstractFFTs.plan_ifft!(A::JLArray; kw_args...) = FFTPlan(plan_ifft!(A.data; kw_args...))
+ AbstractFFTs.plan_ifft(A::JLArray; kw_args...) = FFTPlan(plan_ifft(A.data; kw_args...))
+
+ function Base.:(*)(plan::FFTPlan, A::JLArray)
+     x = plan.p * A.data
+     JLArray(x)
+ end
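
A usage sketch of the plan wrappers above (illustrative; it assumes FFTW is loaded to supply the plans for the underlying `Array` data):

    using FFTW                # assumed provider of the wrapped Array plans
    a = JLArray(rand(ComplexF64, 8))
    p = plan_fft(a)           # FFTPlan wrapping an Array plan for a.data
    b = p * a                 # applies the wrapped plan and rewraps the result as a JLArray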
+
+
+
+ #
+ # AbstractGPUArray interface
+ #
+
+ GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
+     reshape(reinterpret(T, A.data), size)
+
+
+ # # execution
+
+ struct JLBackend <: AbstractGPUBackend end
+
+ GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
+
mutable struct JLState{N}
    blockdim::NTuple{N, Int}
    griddim::NTuple{N, Int}
@@ -148,27 +177,12 @@ function JLState(state::JLState{N}, threadidx::NTuple{N}) where N
    )
end

- function LocalMemory(state::JLState, ::Type{T}, ::Val{N}, ::Val{C}) where {T, N, C}
-     state.localmem_counter += 1
-     lmems = state.localmems[blockidx_x(state)]
-     # first invocation in block
-     if length(lmems) < state.localmem_counter
-         lmem = fill(zero(T), N)
-         push!(lmems, lmem)
-         return lmem
-     else
-         return lmems[state.localmem_counter]
-     end
- end
-
- function AbstractDeviceArray(ptr::Array, shape::NTuple{N, Integer}) where N
-     reshape(ptr, shape)
- end
- function AbstractDeviceArray(ptr::Array, shape::Vararg{Integer, N}) where N
-     reshape(ptr, shape)
- end
+ to_device(state, x::JLArray) = x.data
+ to_device(state, x::Tuple) = to_device.(Ref(state), x)
+ to_device(state, x::Base.RefValue{<:JLArray}) = Base.RefValue(to_device(state, x[]))
+ to_device(state, x) = x

- function _gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
+ function GPUArrays._gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
    blocks, threads = blocks_threads
    idx = ntuple(i->1, length(blocks))
    blockdim = blocks
@@ -177,10 +191,9 @@ function _gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T})
    tasks = Array{Task}(undef, threads...)
    for blockidx in CartesianIndices(blockdim)
        state.blockidx = blockidx.I
-         block_args = to_blocks.(Ref(state), device_args)
        for threadidx in CartesianIndices(threads)
            thread_state = JLState(state, threadidx.I)
-             tasks[threadidx] = @async @allowscalar f(thread_state, block_args...)
+             tasks[threadidx] = @async @allowscalar f(thread_state, device_args...)
            # TODO: @async obfuscates the trace to any exception which happens during f
        end
        for t in tasks
@@ -190,47 +203,69 @@ function _gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T})
    return
end
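
This is the overload that GPUArrays' `gpu_call` entry point dispatches to. A hedged sketch of a kernel running through it; the `gpu_call(kernel, array, args)` calling convention is assumed from the GPUArrays version this commit targets, and the `blockidx_x`/`blockdim_x`/`threadidx_x` intrinsics for `JLState` are the ones defined in the indexing section further down:

    function memset_kernel(state, a, val)
        i = (GPUArrays.blockidx_x(state) - 1) * GPUArrays.blockdim_x(state) +
            GPUArrays.threadidx_x(state)
        if i <= length(a)
            @inbounds a[i] = val
        end
        return
    end

    a = JLArray(zeros(Float32, 16))
    gpu_call(memset_kernel, a, (a, 1f0))   # assumed entry point; each task writes one element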

- # "intrinsics"
- struct JLDevice end
- device(x::JLArray) = JLDevice()
- threads(dev::JLDevice) = 256
-
- @inline function synchronize_threads(::JLState)
-     #=
-     All threads are getting started asynchronously,so a yield will
-     yield to the next execution of the same function, which should call yield
-     at the exact same point in the program, leading to a chain of yields effectively syncing
-     the tasks (threads).
-     =#
+
+ # # gpu intrinsics
+
+ @inline function GPUArrays.synchronize_threads(::JLState)
+     # All threads are getting started asynchronously, so a yield will yield to the next
+     # execution of the same function, which should call yield at the exact same point in the
+     # program, leading to a chain of yields effectively syncing the tasks (threads).
    yield()
    return
end

- for (i, sym) in enumerate((:x, :y, :z))
-     for f in (:blockidx, :blockdim, :threadidx, :griddim)
-         fname = Symbol(string(f, '_', sym))
-         @eval $fname(state::JLState) = Int(state.$f[$i])
+ function GPUArrays.LocalMemory(state::JLState, ::Type{T}, ::Val{N}, ::Val{C}) where {T, N, C}
+     state.localmem_counter += 1
+     lmems = state.localmems[blockidx_x(state)]
+
+     # first invocation in block
+     if length(lmems) < state.localmem_counter
+         lmem = fill(zero(T), N)
+         push!(lmems, lmem)
+         return lmem
+     else
+         return lmems[state.localmem_counter]
    end
end

- blas_module(::JLArray) = LinearAlgebra.BLAS
- blasbuffer(A::JLArray) = A.data

- # defining our own plan type is the easiest way to pass around the plans in FFTW interface
- # without ambiguities
+ # # device properties

- struct FFTPlan{T}
-     p::T
+ struct JLDevice end
+
+ GPUArrays.device(x::JLArray) = JLDevice()
+
+ GPUArrays.threads(dev::JLDevice) = 256
+
+
+ # # linear algebra
+
+ using LinearAlgebra
+
+ GPUArrays.blas_module(::JLArray) = LinearAlgebra.BLAS
+ GPUArrays.blasbuffer(A::JLArray) = A.data
+
+
+
+ #
+ # AbstractDeviceArray interface
+ #
+
+ function GPUArrays.AbstractDeviceArray(ptr::Array, shape::NTuple{N, Integer}) where N
+     reshape(ptr, shape)
+ end
+ function GPUArrays.AbstractDeviceArray(ptr::Array, shape::Vararg{Integer, N}) where N
+     reshape(ptr, shape)
end
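
On the CPU a device-side array is just a reshaped `Array`, so both methods reduce to `reshape`. For example (illustrative):

    v = zeros(Float32, 6)
    m = GPUArrays.AbstractDeviceArray(v, 2, 3)   # a 2×3 array sharing v's memory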

- AbstractFFTs.plan_fft(A::JLArray; kw_args...) = FFTPlan(plan_fft(A.data; kw_args...))
- AbstractFFTs.plan_fft!(A::JLArray; kw_args...) = FFTPlan(plan_fft!(A.data; kw_args...))
- AbstractFFTs.plan_bfft!(A::JLArray; kw_args...) = FFTPlan(plan_bfft!(A.data; kw_args...))
- AbstractFFTs.plan_bfft(A::JLArray; kw_args...) = FFTPlan(plan_bfft(A.data; kw_args...))
- AbstractFFTs.plan_ifft!(A::JLArray; kw_args...) = FFTPlan(plan_ifft!(A.data; kw_args...))
- AbstractFFTs.plan_ifft(A::JLArray; kw_args...) = FFTPlan(plan_ifft(A.data; kw_args...))

- function Base.:(*)(plan::FFTPlan, A::JLArray)
-     x = plan.p * A.data
-     JLArray(x)
+ # # indexing
+
+ for (i, sym) in enumerate((:x, :y, :z))
+     for f in (:blockidx, :blockdim, :threadidx, :griddim)
+         fname = Symbol(string(f, '_', sym))
+         @eval GPUArrays.$fname(state::JLState) = Int(state.$f[$i])
+     end
+ end
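
For reference, the `@eval` loop above generates twelve one-line methods, for example:

    GPUArrays.blockidx_x(state::JLState) = Int(state.blockidx[1])
    GPUArrays.threadidx_y(state::JLState) = Int(state.threadidx[2])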
+
end