1
- # Very simple Julia backend which is just for testing the implementation
2
- # and can be used as a reference implementation
1
+ # Very simple Julia back-end which is just for testing the implementation and can be used as
2
+ # a reference implementation
3
+
4
+
5
+ # # construction
3
6
4
7
struct JLArray{T, N} <: GPUArray{T, N}
5
8
data:: Array{T, N}
12
15
13
16
JLArray (data:: AbstractArray{T, N} , size:: Dims{N} ) where {T,N} = JLArray {T,N} (data, size)
14
17
18
+ (:: Type{<: JLArray{T}} )(x:: AbstractArray ) where T = JLArray (convert (Array{T}, x), size (x))
19
+
20
+ function JLArray {T, N} (size:: NTuple{N, Integer} ) where {T, N}
21
+ JLArray {T, N} (Array {T, N} (undef, size), size)
22
+ end
23
+
24
+
25
+ # # getters
26
+
27
+ size (x:: JLArray ) = x. size
28
+
29
+ pointer (x:: JLArray ) = pointer (x. data)
30
+
31
+
32
+ # # I/O
33
+
15
34
Base. show (io:: IO , x:: JLArray ) = show (io, collect (x))
16
35
Base. show (io:: IO , x:: LinearAlgebra.Adjoint{<:Any,<:JLArray} ) = show (io, LinearAlgebra. adjoint (collect (x. parent)))
17
36
Base. show (io:: IO , x:: LinearAlgebra.Transpose{<:Any,<:JLArray} ) = show (io, LinearAlgebra. transpose (collect (x. parent)))
@@ -20,15 +39,16 @@ Base.show(io::IO, ::MIME"text/plain", x::JLArray) = show(io, MIME"text/plain"(),
20
39
Base. show (io:: IO , :: MIME"text/plain" , x:: LinearAlgebra.Adjoint{<:Any,<:JLArray} ) = show (io, MIME " text/plain" (), LinearAlgebra. adjoint (collect (x. parent)))
21
40
Base. show (io:: IO , :: MIME"text/plain" , x:: LinearAlgebra.Transpose{<:Any,<:JLArray} ) = show (io, MIME " text/plain" (), LinearAlgebra. transpose (collect (x. parent)))
22
41
42
+
43
+ # # other
44
+
23
45
"""
24
46
Thread group local memory
25
47
"""
26
48
struct LocalMem{N, T}
27
49
x:: NTuple{N, Vector{T}}
28
50
end
29
51
30
- size (x:: JLArray ) = x. size
31
- pointer (x:: JLArray ) = pointer (x. data)
32
52
to_device (state, x:: JLArray ) = x. data
33
53
to_device (state, x:: Tuple ) = to_device .(Ref (state), x)
34
54
to_device (state, x:: RefValue{<: JLArray} ) = RefValue (to_device (state, x[]))
@@ -40,12 +60,6 @@ to_blocks(state, x) = x
40
60
# unpacks local memory for each block
41
61
to_blocks (state, x:: LocalMem ) = x. x[blockidx_x (state)]
42
62
43
- (:: Type{<: JLArray{T}} )(x:: AbstractArray ) where T = JLArray (convert (Array{T}, x), size (x))
44
-
45
- function JLArray {T, N} (size:: NTuple{N, Integer} ) where {T, N}
46
- JLArray {T, N} (Array {T, N} (undef, size), size)
47
- end
48
-
49
63
similar (:: Type{<: JLArray} , :: Type{T} , size:: Base.Dims{N} ) where {T, N} = JLArray {T, N} (size)
50
64
51
65
function unsafe_reinterpret (:: Type{T} , A:: JLArray{ET} , size:: NTuple{N, Integer} ) where {T, ET, N}
@@ -131,7 +145,8 @@ function _gpu_call(f, A::JLArray, args::Tuple, blocks_threads::Tuple{T, T}) wher
131
145
block_args = to_blocks .(Ref (state), device_args)
132
146
for threadidx in CartesianIndices (threads)
133
147
thread_state = JLState (state, threadidx. I)
134
- tasks[threadidx] = @async f (thread_state, block_args... )
148
+ tasks[threadidx] = @async @allowscalar f (thread_state, block_args... )
149
+ # TODO : @async obfuscates the trace to any exception which happens during f
135
150
end
136
151
for t in tasks
137
152
fetch (t)
@@ -146,7 +161,6 @@ device(x::JLArray) = JLDevice()
146
161
threads (dev:: JLDevice ) = 256
147
162
blocks (dev:: JLDevice ) = (256 , 256 , 256 )
148
163
149
-
150
164
@inline function synchronize_threads (:: JLState )
151
165
#=
152
166
All threads are getting started asynchronously,so a yield will
168
182
blas_module (:: JLArray ) = LinearAlgebra. BLAS
169
183
blasbuffer (A:: JLArray ) = A. data
170
184
185
+ # defining our own plan type is the easiest way to pass around the plans in Base interface
186
+ # without ambiguities
171
187
172
- # defining our own plan type is the easiest way to pass around the plans in Base interface without ambiguities
173
188
struct FFTPlan{T}
174
189
p:: T
175
190
end
@@ -192,7 +207,6 @@ function plan_ifft(A::JLArray; kw_args...)
192
207
FFTPlan (plan_ifft (A. data; kw_args... ))
193
208
end
194
209
195
-
196
210
function * (plan:: FFTPlan , A:: JLArray )
197
211
x = plan. p * A. data
198
212
JLArray (x)
0 commit comments