Commit 3f2ec79

Make AbstractDeviceArray a proper type.
1 parent b4d2f78 commit 3f2ec79

File tree (5 files changed, +99 -84 lines):

  src/GPUArrays.jl
  src/array.jl
  src/device/abstractarray.jl
  src/device/gpu.jl (deleted)
  src/device/synchronization.jl (new file)

src/GPUArrays.jl

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ using Adapt
 # device array
 include("device/abstractarray.jl")
 include("device/indexing.jl")
-include("device/gpu.jl")
+include("device/synchronization.jl")
 
 # host array
 include("host/abstractarray.jl")

src/array.jl

Lines changed: 58 additions & 37 deletions
@@ -6,6 +6,14 @@ using GPUArrays
 
 export JLArray
 
+
+#
+# Host array
+#
+
+# the definition of a host array type, implementing different Base interfaces
+# to make it function properly and behave like the Base Array type.
+
 struct JLArray{T, N} <: AbstractGPUArray{T, N}
     data::Array{T, N}
     dims::Dims{N}
@@ -15,12 +23,7 @@ struct JLArray{T, N} <: AbstractGPUArray{T, N}
     end
 end
 
-
-#
-# AbstractArray interface
-#
-
-## typical constructors
+## constructors
 
 # type and dimensionality specified, accepting dims as tuples of Ints
 JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} =
@@ -139,6 +142,8 @@ end
 # AbstractGPUArray interface
 #
 
+# implementation of GPUArrays-specific interfaces
+
 GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
     reshape(reinterpret(T, A.data), size)
 
@@ -177,7 +182,7 @@ function JLState(state::JLState{N}, threadidx::NTuple{N}) where N
     )
 end
 
-to_device(state, x::JLArray) = x.data
+to_device(state, x::JLArray{T,N}) where {T,N} = JLDeviceArray{T,N}(x.data, x.dims)
 to_device(state, x::Tuple) = to_device.(Ref(state), x)
 to_device(state, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(state, x[]))
 to_device(state, x) = x
@@ -205,31 +210,6 @@ function GPUArrays._gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tup
 end
 
 
-## gpu intrinsics
-
-@inline function GPUArrays.synchronize_threads(::JLState)
-    # All threads are getting started asynchronously, so a yield will yield to the next
-    # execution of the same function, which should call yield at the exact same point in the
-    # program, leading to a chain of yields effectively syncing the tasks (threads).
-    yield()
-    return
-end
-
-function GPUArrays.LocalMemory(state::JLState, ::Type{T}, ::Val{N}, ::Val{C}) where {T, N, C}
-    state.localmem_counter += 1
-    lmems = state.localmems[blockidx_x(state)]
-
-    # first invocation in block
-    if length(lmems) < state.localmem_counter
-        lmem = fill(zero(T), N)
-        push!(lmems, lmem)
-        return lmem
-    else
-        return lmems[state.localmem_counter]
-    end
-end
-
-
 ## device properties
 
 struct JLDevice end
@@ -249,24 +229,65 @@ GPUArrays.blasbuffer(A::JLArray) = A.data
 
 
 #
-# AbstractDeviceArray interface
+# Device array
 #
 
-function GPUArrays.AbstractDeviceArray(ptr::Array, shape::NTuple{N, Integer}) where N
-    reshape(ptr, shape)
+# definition of a minimal device array type that supports the subset of operations
+# that are used in GPUArrays kernels
+
+struct JLDeviceArray{T, N} <: AbstractDeviceArray{T, N}
+    data::Array{T, N}
+    dims::Dims{N}
+
+    function JLDeviceArray{T,N}(data::Array{T, N}, dims::Dims{N}) where {T,N}
+        new(data, dims)
+    end
 end
-function GPUArrays.AbstractDeviceArray(ptr::Array, shape::Vararg{Integer, N}) where N
-    reshape(ptr, shape)
+
+function GPUArrays.LocalMemory(state::JLState, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
+    state.localmem_counter += 1
+    lmems = state.localmems[blockidx_x(state)]
+
+    # first invocation in block
+    data = if length(lmems) < state.localmem_counter
+        lmem = fill(zero(T), dims)
+        push!(lmems, lmem)
+        lmem
+    else
+        lmems[state.localmem_counter]
+    end
+
+    N = length(dims)
+    JLDeviceArray{T,N}(data, tuple(dims...))
 end
 
 
+## array interface
+
+Base.size(x::JLDeviceArray) = x.dims
+
+
 ## indexing
 
+@inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(A.data, index)
+@inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(A.data, x, index)
+
 for (i, sym) in enumerate((:x, :y, :z))
     for f in (:blockidx, :blockdim, :threadidx, :griddim)
         fname = Symbol(string(f, '_', sym))
        @eval GPUArrays.$fname(state::JLState) = Int(state.$f[$i])
     end
 end
 
+
+## synchronization
+
+@inline function GPUArrays.synchronize_threads(::JLState)
+    # All threads are getting started asynchronously, so a yield will yield to the next
+    # execution of the same function, which should call yield at the exact same point in the
+    # program, leading to a chain of yields effectively syncing the tasks (threads).
+    yield()
+    return
+end
+
 end
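Side note (not part of the commit): the to_device change above is the heart of the patch. Previously a JLArray was unwrapped to its raw Base.Array before entering a kernel; now it is wrapped in JLDeviceArray, a proper AbstractDeviceArray subtype, so device-side code dispatches on a device array type. A minimal self-contained sketch of that pattern, with all names invented for illustration, might look like this:

    # Illustrative only; HostSide/DeviceSide/to_device are made-up stand-ins for
    # JLArray, JLDeviceArray, and the to_device method in this file.
    struct HostSide{T,N}
        data::Array{T,N}
    end

    # a thin wrapper exposing only what kernels need: size plus linear indexing
    struct DeviceSide{T,N} <: AbstractArray{T,N}
        data::Array{T,N}
        dims::NTuple{N,Int}
    end

    Base.size(A::DeviceSide) = A.dims
    Base.getindex(A::DeviceSide, i::Integer) = A.data[i]
    Base.setindex!(A::DeviceSide, v, i::Integer) = (A.data[i] = v)

    # wrap instead of unwrapping to a raw Array
    to_device(x::HostSide{T,N}) where {T,N} = DeviceSide{T,N}(x.data, size(x.data))

    h = HostSide(zeros(Float32, 4))
    d = to_device(h)
    d[1] = 42f0
    h.data[1]   # 42.0f0: in this CPU back-end the wrapper aliases the host storage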

src/device/abstractarray.jl

Lines changed: 27 additions & 1 deletion
@@ -1,6 +1,6 @@
 # on-device functionality
 
-export AbstractDeviceArray
+export AbstractDeviceArray, @LocalMemory
 
 
 ## device array
@@ -24,3 +24,29 @@ function Base.sum(A::AbstractDeviceArray{T}) where T
     end
     acc
 end
+
+
+## thread-local array
+
+const shmem_counter = Ref{Int}(0)
+
+"""
+Creates a local static memory shared inside one block.
+Equivalent to `__local` of OpenCL or `__shared__ (<variable>)` of CUDA.
+"""
+macro LocalMemory(state, T, N)
+    id = (shmem_counter[] += 1)
+    quote
+        LocalMemory($(esc(state)), $(esc(T)), Val($(esc(N))), Val($id))
+    end
+end
+
+"""
+Creates a block local array pointer with `T` being the element type
+and `N` the length. Both T and N need to be static! C is a counter for
+approriately get the correct Local mem id in CUDAnative.
+This is an internal method which needs to be overloaded by the GPU Array backends
+"""
+function LocalMemory(state, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
+    error("Not implemented") # COV_EXCL_LINE
+end
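To see how these pieces are meant to be used together, here is a hedged sketch (not from this commit; the kernel body, its names, and the 16-thread launch are assumptions) of a kernel requesting block-local storage through the new macro:

    # Hypothetical kernel body; only @LocalMemory, threadidx_x and synchronize_threads
    # are GPUArrays primitives, everything else is invented for the sketch.
    function scratch_kernel(state, out)
        # expands to LocalMemory(state, Float32, Val(16), Val(id)) with a unique id
        tmp = @LocalMemory(state, Float32, 16)
        i = threadidx_x(state)
        tmp[i] = Float32(i)
        synchronize_threads(state)   # make every thread's write visible block-wide
        out[i] = tmp[17 - i]         # read a slot written by another thread
        return
    end

Launching the kernel (via gpu_call) is omitted here; the sketch assumes a single block of 16 threads so that both tmp[i] and tmp[17 - i] stay in bounds.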

src/device/gpu.jl

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/device/synchronization.jl

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# synchronization
+
+export synchronize_threads
+
+"""
+    synchronize_threads(state)
+
+in CUDA terms `__synchronize`
+in OpenCL terms: `barrier(CLK_LOCAL_MEM_FENCE)`
+"""
+function synchronize_threads(state)
+    error("Not implemented") # COV_EXCL_LINE
+end
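The JLState implementation of this stub (see src/array.jl above) relies on cooperative task scheduling: every "thread" is a task, and a yield() at the same program point lets all tasks reach that point before any continues. A tiny single-threaded toy, not part of the commit and dependent on the scheduler running tasks in FIFO order, illustrates the idea:

    # Toy illustration of the yield-as-barrier trick used by the JLState back-end.
    function toy_thread(log, i)
        push!(log, (i, :before_barrier))
        yield()                      # stands in for synchronize_threads(state)
        push!(log, (i, :after_barrier))
    end

    log = Tuple{Int,Symbol}[]
    tasks = [@async toy_thread(log, i) for i in 1:3]
    foreach(wait, tasks)
    log   # all :before_barrier entries precede every :after_barrier entry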
