
Commit d5a199b

Clean up JLArray reference implementation.

Parent: 712ee5e

File tree: 4 files changed (+119, -80 lines)

src/array.jl

Lines changed: 113 additions & 78 deletions
@@ -1,4 +1,8 @@
-# reference implementation of the GPUArrays interfaces
+# reference implementation of a CPU-based array type
+
+module JLArrays
+
+using GPUArrays
 
 export JLArray
 
@@ -12,7 +16,11 @@ struct JLArray{T, N} <: AbstractGPUArray{T, N}
 end
 
 
-## construction
+#
+# AbstractArray interface
+#
+
+## typical constructors
 
 # type and dimensionality specified, accepting dims as tuples of Ints
 JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} =
@@ -29,7 +37,6 @@ JLArray{T}(::UndefInitializer, dims::Integer...) where {T} =
 # empty vector constructor
 JLArray{T,1}() where {T} = JLArray{T,1}(undef, 0)
 
-
 Base.similar(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(undef, size(a))
 Base.similar(a::JLArray{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
 Base.similar(a::JLArray, ::Type{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
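
For orientation, a minimal sketch of what these constructors and similar methods provide once the module from this commit is loaded (values are illustrative):

using GPUArrays.JLArrays

a = JLArray{Float32,2}(undef, (2, 3))  # eltype and rank explicit, dims as a tuple
b = JLArray{Float32}(undef, 2, 3)      # rank inferred from the number of dims
v = JLArray{Int,1}()                   # the empty-vector constructor above
c = similar(a, Float64, (4, 4))        # new eltype and dims, same container type
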
@@ -64,6 +71,8 @@ Base.convert(::Type{T}, x::T) where T <: JLArray = x
 
 ## broadcast
 
+using Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
+
 BroadcastStyle(::Type{<:JLArray}) = ArrayStyle{JLArray}()
 
 function Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}) where T
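
The BroadcastStyle hook above is what routes dotted syntax into the JLArray-specific similar. A minimal sketch (this assumes a JLArray(::Array) convenience constructor, which the FFT code further down also relies on):

a = JLArray(rand(Float32, 4))
b = a .+ 1f0           # lowers to a Broadcasted{ArrayStyle{JLArray}}
@assert b isa JLArray  # materialized through the similar overloads above
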
@@ -72,29 +81,8 @@ end
 
 Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}, dims...) where {T} = JLArray{T}(undef, dims...)
 
-## gpuarray interface
-
-struct JLBackend <: GPUBackend end
-backend(::Type{<:JLArray}) = JLBackend()
-
-"""
-Thread group local memory
-"""
-struct LocalMem{N, T}
-    x::NTuple{N, Vector{T}}
-end
 
-to_device(state, x::JLArray) = x.data
-to_device(state, x::Tuple) = to_device.(Ref(state), x)
-to_device(state, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(state, x[]))
-to_device(state, x) = x
-
-to_blocks(state, x) = x
-# unpacks local memory for each block
-to_blocks(state, x::LocalMem) = x.x[blockidx_x(state)]
-
-unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
-    reshape(reinterpret(T, A.data), size)
+## memory operations
 
 function Base.copyto!(dest::Array{T}, d_offset::Integer,
                       source::JLArray{T}, s_offset::Integer,
@@ -103,6 +91,7 @@ function Base.copyto!(dest::Array{T}, d_offset::Integer,
     @boundscheck checkbounds(source, s_offset+amount-1)
     copyto!(dest, d_offset, source.data, s_offset, amount)
 end
+
 function Base.copyto!(dest::JLArray{T}, d_offset::Integer,
                       source::Array{T}, s_offset::Integer,
                       amount::Integer) where T
@@ -111,6 +100,7 @@ function Base.copyto!(dest::JLArray{T}, d_offset::Integer,
     copyto!(dest.data, d_offset, source, s_offset, amount)
     dest
 end
+
 function Base.copyto!(dest::JLArray{T}, d_offset::Integer,
                       source::JLArray{T}, s_offset::Integer,
                       amount::Integer) where T
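
Together these methods cover host-to-device, device-to-host, and device-to-device copies; for example:

src = JLArray(collect(1.0:10.0))
dst = zeros(10)
copyto!(dst, 1, src, 1, 10)  # "device" to host
copyto!(src, 1, dst, 1, 10)  # host back to "device"
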
@@ -120,6 +110,45 @@ function Base.copyto!(dest::JLArray{T}, d_offset::Integer,
     dest
 end
 
+## fft
+
+using AbstractFFTs
+
+# defining our own plan type is the easiest way to pass around the plans in FFTW interface
+# without ambiguities
+
+struct FFTPlan{T}
+    p::T
+end
+
+AbstractFFTs.plan_fft(A::JLArray; kw_args...) = FFTPlan(plan_fft(A.data; kw_args...))
+AbstractFFTs.plan_fft!(A::JLArray; kw_args...) = FFTPlan(plan_fft!(A.data; kw_args...))
+AbstractFFTs.plan_bfft!(A::JLArray; kw_args...) = FFTPlan(plan_bfft!(A.data; kw_args...))
+AbstractFFTs.plan_bfft(A::JLArray; kw_args...) = FFTPlan(plan_bfft(A.data; kw_args...))
+AbstractFFTs.plan_ifft!(A::JLArray; kw_args...) = FFTPlan(plan_ifft!(A.data; kw_args...))
+AbstractFFTs.plan_ifft(A::JLArray; kw_args...) = FFTPlan(plan_ifft(A.data; kw_args...))
+
+function Base.:(*)(plan::FFTPlan, A::JLArray)
+    x = plan.p * A.data
+    JLArray(x)
+end
+
+
+
+#
+# AbstractGPUArray interface
+#
+
+GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
+    reshape(reinterpret(T, A.data), size)
+
+
+## execution
+
+struct JLBackend <: AbstractGPUBackend end
+
+GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
+
 mutable struct JLState{N}
     blockdim::NTuple{N, Int}
     griddim::NTuple{N, Int}
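
The FFTPlan wrapper defers entirely to a plan for the underlying Array. A sketch, assuming FFTW is also loaded so AbstractFFTs has a concrete plan provider:

using AbstractFFTs, FFTW

A = JLArray(rand(ComplexF32, 8))
p = plan_fft(A)  # dispatches to the method above, wrapping an FFTW plan
B = p * A        # applies the wrapped plan, rewraps the result as a JLArray
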
@@ -148,27 +177,12 @@ function JLState(state::JLState{N}, threadidx::NTuple{N}) where N
     )
 end
 
-function LocalMemory(state::JLState, ::Type{T}, ::Val{N}, ::Val{C}) where {T, N, C}
-    state.localmem_counter += 1
-    lmems = state.localmems[blockidx_x(state)]
-    # first invocation in block
-    if length(lmems) < state.localmem_counter
-        lmem = fill(zero(T), N)
-        push!(lmems, lmem)
-        return lmem
-    else
-        return lmems[state.localmem_counter]
-    end
-end
-
-function AbstractDeviceArray(ptr::Array, shape::NTuple{N, Integer}) where N
-    reshape(ptr, shape)
-end
-function AbstractDeviceArray(ptr::Array, shape::Vararg{Integer, N}) where N
-    reshape(ptr, shape)
-end
+to_device(state, x::JLArray) = x.data
+to_device(state, x::Tuple) = to_device.(Ref(state), x)
+to_device(state, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(state, x[]))
+to_device(state, x) = x
 
-function _gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
+function GPUArrays._gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
     blocks, threads = blocks_threads
     idx = ntuple(i-> 1, length(blocks))
     blockdim = blocks
@@ -177,10 +191,9 @@ function _gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T})
     tasks = Array{Task}(undef, threads...)
     for blockidx in CartesianIndices(blockdim)
         state.blockidx = blockidx.I
-        block_args = to_blocks.(Ref(state), device_args)
         for threadidx in CartesianIndices(threads)
             thread_state = JLState(state, threadidx.I)
-            tasks[threadidx] = @async @allowscalar f(thread_state, block_args...)
+            tasks[threadidx] = @async @allowscalar f(thread_state, device_args...)
             # TODO: @async obfuscates the trace to any exception which happens during f
         end
         for t in tasks
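
This task-based _gpu_call is what backs the public gpu_call entry point. A hedged sketch of a kernel run through it, assuming gpu_call's (kernel, array, args) calling convention and that the indexing intrinsics defined at the end of this file are in scope:

function vadd(state, a, b, c)
    i = (blockidx_x(state) - 1) * blockdim_x(state) + threadidx_x(state)
    i > length(c) && return  # guard against excess threads from the launch heuristic
    c[i] = a[i] + b[i]
    return
end

a = JLArray(rand(Float32, 16))
b = JLArray(rand(Float32, 16))
c = similar(a)
gpu_call(vadd, c, (a, b, c))
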
@@ -190,47 +203,69 @@ function _gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T})
     return
 end
 
-# "intrinsics"
-struct JLDevice end
-device(x::JLArray) = JLDevice()
-threads(dev::JLDevice) = 256
-
-@inline function synchronize_threads(::JLState)
-    #=
-    All threads are getting started asynchronously,so a yield will
-    yield to the next execution of the same function, which should call yield
-    at the exact same point in the program, leading to a chain of yields effectively syncing
-    the tasks (threads).
-    =#
+
+## gpu intrinsics
+
+@inline function GPUArrays.synchronize_threads(::JLState)
+    # All threads are getting started asynchronously, so a yield will yield to the next
+    # execution of the same function, which should call yield at the exact same point in the
+    # program, leading to a chain of yields effectively syncing the tasks (threads).
     yield()
     return
 end
 
-for (i, sym) in enumerate((:x, :y, :z))
-    for f in (:blockidx, :blockdim, :threadidx, :griddim)
-        fname = Symbol(string(f, '_', sym))
-        @eval $fname(state::JLState) = Int(state.$f[$i])
+function GPUArrays.LocalMemory(state::JLState, ::Type{T}, ::Val{N}, ::Val{C}) where {T, N, C}
+    state.localmem_counter += 1
+    lmems = state.localmems[blockidx_x(state)]
+
+    # first invocation in block
+    if length(lmems) < state.localmem_counter
+        lmem = fill(zero(T), N)
+        push!(lmems, lmem)
+        return lmem
+    else
+        return lmems[state.localmem_counter]
     end
 end
 
-blas_module(::JLArray) = LinearAlgebra.BLAS
-blasbuffer(A::JLArray) = A.data
 
-# defining our own plan type is the easiest way to pass around the plans in FFTW interface
-# without ambiguities
+## device properties
 
-struct FFTPlan{T}
-    p::T
+struct JLDevice end
+
+GPUArrays.device(x::JLArray) = JLDevice()
+
+GPUArrays.threads(dev::JLDevice) = 256
+
+
+## linear algebra
+
+using LinearAlgebra
+
+GPUArrays.blas_module(::JLArray) = LinearAlgebra.BLAS
+GPUArrays.blasbuffer(A::JLArray) = A.data
+
+
+
+#
+# AbstractDeviceArray interface
+#
+
+function GPUArrays.AbstractDeviceArray(ptr::Array, shape::NTuple{N, Integer}) where N
+    reshape(ptr, shape)
+end
+function GPUArrays.AbstractDeviceArray(ptr::Array, shape::Vararg{Integer, N}) where N
+    reshape(ptr, shape)
 end
 
-AbstractFFTs.plan_fft(A::JLArray; kw_args...) = FFTPlan(plan_fft(A.data; kw_args...))
-AbstractFFTs.plan_fft!(A::JLArray; kw_args...) = FFTPlan(plan_fft!(A.data; kw_args...))
-AbstractFFTs.plan_bfft!(A::JLArray; kw_args...) = FFTPlan(plan_bfft!(A.data; kw_args...))
-AbstractFFTs.plan_bfft(A::JLArray; kw_args...) = FFTPlan(plan_bfft(A.data; kw_args...))
-AbstractFFTs.plan_ifft!(A::JLArray; kw_args...) = FFTPlan(plan_ifft!(A.data; kw_args...))
-AbstractFFTs.plan_ifft(A::JLArray; kw_args...) = FFTPlan(plan_ifft(A.data; kw_args...))
 
-function Base.:(*)(plan::FFTPlan, A::JLArray)
-    x = plan.p * A.data
-    JLArray(x)
+## indexing
+
+for (i, sym) in enumerate((:x, :y, :z))
+    for f in (:blockidx, :blockdim, :threadidx, :griddim)
+        fname = Symbol(string(f, '_', sym))
+        @eval GPUArrays.$fname(state::JLState) = Int(state.$f[$i])
+    end
+end
+
 end
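
For completeness, the device-property and BLAS hooks above are queried like so (qualified with GPUArrays. since the module only exports JLArray):

a = JLArray(rand(Float32, 4, 4))

dev = GPUArrays.device(a)  # JLDevice()
GPUArrays.threads(dev)     # 256, the hard-coded thread count above

GPUArrays.blas_module(a)   # LinearAlgebra.BLAS; dense ops fall back to BLAS
GPUArrays.blasbuffer(a)    # the wrapped Array that gets handed to BLAS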

src/host/execution.jl

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # kernel execution
 
-export gpu_call, synchronize, thread_blocks_heuristic
+export AbstractGPUBackend, gpu_call, synchronize, thread_blocks_heuristic
 
-abstract type GPUBackend end
+abstract type AbstractGPUBackend end
 
 backend(::Type{T}) where T = error("Can't choose GPU backend for $T")
 
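With the abstract type now exported, downstream packages can subtype it directly. A hedged sketch of the minimal wiring a backend provides (MyArray and MyBackend are hypothetical names; a working backend also overloads GPUArrays._gpu_call, as JLArrays does above):

using GPUArrays

struct MyArray{T,N} <: AbstractGPUArray{T,N}  # hypothetical array type
    data::Array{T,N}
end

struct MyBackend <: AbstractGPUBackend end    # hypothetical backend tag

GPUArrays.backend(::Type{<:MyArray}) = MyBackend()
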
src/host/indexing.jl

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 # host-level indexing
 
+export allowscalar, @allowscalar, assertscalar
+
+
 # mechanism to disallow scalar operations
 
 const scalar_allowed = Ref(true)
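
These newly exported names form the scalar-indexing switch built on the scalar_allowed flag below. A sketch of typical use (the Bool-setter form of allowscalar is an assumption):

using GPUArrays, GPUArrays.JLArrays

a = JLArray(rand(Float32, 10))

allowscalar(false)     # assumed signature: turn scalar getindex/setindex! into errors
x = @allowscalar a[1]  # explicitly permit scalar access for one expression
allowscalar(true)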

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -3,5 +3,6 @@ using GPUArrays, Test
 include("testsuite.jl")
 
 @testset "JLArray" begin
+    using GPUArrays.JLArrays
     TestSuite.test(JLArray)
 end
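
The same entry point is how an external backend would reuse the suite; a sketch with a hypothetical MyArray type:

using GPUArrays, Test
include("testsuite.jl")      # shipped with GPUArrays' tests

@testset "MyArray" begin
    TestSuite.test(MyArray)  # any AbstractGPUArray subtype implementing the interface
end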
