Skip to content

Commit e59378e

Browse files
authored
Merge pull request #233 from JuliaGPU/tb/simplify
Simplify gpu_call
2 parents 18d1512 + 0485343 commit e59378e

21 files changed

+406
-413
lines changed

docs/src/interface.md

Lines changed: 33 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,83 +5,61 @@ implement the interfaces listed on this page. GPUArrays is design around having
55
different array types to represent a GPU array: one that only ever lives on the host, and
66
one that actually can be instantiated on the device (i.e. in kernels).
77

8-
## Host-side
98

10-
Your host-side array type should build on the `AbstractGPUArray` supertype:
9+
## Device functionality
1110

12-
```@docs
13-
AbstractGPUArray
14-
```
15-
16-
First of all, you should implement operations that are expected to be defined for any
17-
`AbstractArray` type. Refer to the Julia manual for more details, or look at the `JLArray`
18-
reference implementation.
19-
20-
To be able to actually use the functionality that is defined for `AbstractGPUArray`s, you
21-
should provide implementations of the following interfaces:
11+
Several types and interfaces are related to the device and execution of code on it. First of
12+
all, you need to provide a type that represents your device and exposes some properties of
13+
it:
2214

2315
```@docs
24-
GPUArrays.unsafe_reinterpret
25-
```
26-
27-
### Devices
28-
29-
```@docs
30-
GPUArrays.device
31-
GPUArrays.synchronize
16+
GPUArrays.AbstractGPUDevice
17+
GPUArrays.threads
3218
```
3319

34-
### Execution
20+
Another important set of interfaces relates to executing code on the device:
3521

3622
```@docs
3723
GPUArrays.AbstractGPUBackend
38-
GPUArrays.backend
39-
```
40-
41-
```@docs
42-
GPUArrays._gpu_call
24+
GPUArrays.AbstractKernelContext
25+
GPUArrays.gpu_call
26+
GPUArrays.synchronize
27+
GPUArrays.thread_block_heuristic
4328
```
4429

45-
### Linear algebra
30+
Finally, you need to provide implementations of certain methods that will be executed on the
31+
device itself:
4632

4733
```@docs
48-
GPUArrays.blas_module
49-
GPUArrays.blasbuffer
34+
GPUArrays.AbstractDeviceArray
35+
GPUArrays.LocalMemory
36+
GPUArrays.synchronize_threads
37+
GPUArrays.blockidx
38+
GPUArrays.blockdim
39+
GPUArrays.threadidx
40+
GPUArrays.griddim
5041
```
5142

5243

53-
## Device-side
44+
## Host abstractions
5445

55-
To work with GPU memory on the device itself, e.g. within a kernel, we need a different
56-
type: Most functionality will behave differently when running on the GPU, e.g., accessing
57-
memory directly instead of copying it to the host. We should also take care not to call into
58-
any host library, such as the Julia runtime or the system's math library.
46+
You should provide an array type that builds on the `AbstractGPUArray` supertype:
5947

6048
```@docs
61-
AbstractDeviceArray
49+
AbstractGPUArray
6250
```
6351

64-
Your device array type should again implement the core elements of the `AbstractArray`
65-
interface, such as indexing and certain getters. Refer to the Julia manual for more details,
66-
or look at the `JLDeviceArray` reference implementation.
52+
First of all, you should implement operations that are expected to be defined for any
53+
`AbstractArray` type. Refer to the Julia manual for more details, or look at the `JLArray`
54+
reference implementation.
6755

68-
You should also provide implementations of several "GPU intrinsics". To make sure the
69-
correct implementation is called, the first argument to these intrinsics will be the kernel
70-
state object from before.
56+
To be able to actually use the functionality that is defined for `AbstractGPUArray`s, you
57+
should provide implementations of the following interfaces:
7158

7259
```@docs
73-
GPUArrays.LocalMemory
74-
GPUArrays.synchronize_threads
75-
GPUArrays.blockidx_x
76-
GPUArrays.blockidx_y
77-
GPUArrays.blockidx_z
78-
GPUArrays.blockdim_x
79-
GPUArrays.blockdim_y
80-
GPUArrays.blockdim_z
81-
GPUArrays.threadidx_x
82-
GPUArrays.threadidx_y
83-
GPUArrays.threadidx_z
84-
GPUArrays.griddim_x
85-
GPUArrays.griddim_y
86-
GPUArrays.griddim_z
60+
GPUArrays.backend
61+
GPUArrays.device
62+
GPUArrays.unsafe_reinterpret
63+
GPUArrays.blas_module
64+
GPUArrays.blasbuffer
8765
```

src/GPUArrays.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,19 @@ using AbstractFFTs
1212

1313
using Adapt
1414

15-
# device array
15+
# device functionality
16+
include("device/device.jl")
17+
include("device/execution.jl")
18+
## executed on-device
1619
include("device/abstractarray.jl")
1720
include("device/indexing.jl")
21+
include("device/memory.jl")
1822
include("device/synchronization.jl")
1923

20-
# host array
24+
# host abstractions
2125
include("host/abstractarray.jl")
22-
include("host/devices.jl")
23-
include("host/execution.jl")
2426
include("host/construction.jl")
25-
## integrations and specialized functionality
27+
## integrations and specialized methods
2628
include("host/base.jl")
2729
include("host/indexing.jl")
2830
include("host/broadcast.jl")
@@ -32,7 +34,7 @@ include("host/random.jl")
3234
include("host/quirks.jl")
3335

3436
# CPU reference implementation
35-
include("array.jl")
37+
include("reference.jl")
3638

3739

3840
end # module

src/device/abstractarray.jl

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# on-device functionality
1+
# on-device array type
22

3-
export AbstractDeviceArray, @LocalMemory
3+
export AbstractDeviceArray
44

55

66
## device array
@@ -31,29 +31,3 @@ function Base.sum(A::AbstractDeviceArray{T}) where T
3131
end
3232
acc
3333
end
34-
35-
36-
## thread-local array
37-
38-
const shmem_counter = Ref{Int}(0)
39-
40-
"""
41-
Creates a local static memory shared inside one block.
42-
Equivalent to `__local` of OpenCL or `__shared__ (<variable>)` of CUDA.
43-
"""
44-
macro LocalMemory(state, T, N)
45-
id = (shmem_counter[] += 1)
46-
quote
47-
LocalMemory($(esc(state)), $(esc(T)), Val($(esc(N))), Val($id))
48-
end
49-
end
50-
51-
"""
52-
Creates a block local array pointer with `T` being the element type
53-
and `N` the length. Both T and N need to be static! C is a counter for
54-
approriately get the correct Local mem id in CUDAnative.
55-
This is an internal method which needs to be overloaded by the GPU Array backends
56-
"""
57-
function LocalMemory(state, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
58-
error("Not implemented") # COV_EXCL_LINE
59-
end

src/device/device.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# device management and properties
2+
3+
export AbstractGPUDevice
4+
5+
abstract type AbstractGPUDevice end
6+
7+
"""
8+
device(A::AbstractArray)
9+
10+
Gets the device associated to the Array `A`
11+
"""
12+
device(A::AbstractArray) = error("This array is not a GPU array") # COV_EXCL_LINE
13+
14+
"""
15+
Hardware threads of device
16+
"""
17+
threads(::AbstractGPUDevice) = error("Not implemented") # COV_EXCL_LINE

src/device/execution.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# kernel execution
2+
3+
export AbstractGPUBackend, AbstractKernelContext, gpu_call, synchronize, thread_blocks_heuristic
4+
5+
abstract type AbstractGPUBackend end
6+
7+
abstract type AbstractKernelContext end
8+
9+
"""
10+
backend(T::Type)
11+
backend(x)
12+
13+
Gets the GPUArrays back-end responsible for managing arrays of type `T`.
14+
"""
15+
backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
16+
backend(x) = backend(typeof(x))
17+
18+
"""
19+
gpu_call(kernel::Function, arg0, args...; kwargs...)
20+
21+
Executes `kernel` on the device that backs `arg` (see [`backend`](@ref)), passing along any
22+
arguments `args`. Additionally, the kernel will be passed the kernel execution context (see
23+
[`AbstractKernelContext`]), so its signature should be `(ctx::AbstractKernelContext, arg0,
24+
args...)`.
25+
26+
The keyword arguments `kwargs` are not passed to the function, but are interpreted on the
27+
host to influence how the kernel is executed. The following keyword arguments are supported:
28+
29+
- `target::AbstractArray`: specify which array object to use for determining execution
30+
properties (defaults to the first argument `arg0`).
31+
- `total_threads::Int`: how many threads should be launched _in total_. The actual number of
32+
threads and blocks is determined using a heuristic. Defaults to the length of `arg0` if
33+
no other keyword arguments that influence the launch configuration are specified.
34+
- `threads::Int` and `blocks::Int`: configure exactly how many threads and blocks are
35+
launched. This cannot be used in combination with the `total_threads` argument.
36+
"""
37+
function gpu_call(kernel::Base.Callable, args...;
38+
target::AbstractArray=first(args),
39+
total_threads::Union{Int,Nothing}=nothing,
40+
threads::Union{Int,Nothing}=nothing,
41+
blocks::Union{Int,Nothing}=nothing,
42+
kwargs...)
43+
# determine how many threads/blocks to launch
44+
if total_threads===nothing && threads===nothing && blocks===nothing
45+
total_threads = length(target)
46+
end
47+
if total_threads !== nothing
48+
if threads !== nothing || blocks !== nothing
49+
error("Cannot specify both total_threads and threads/blocks configuration")
50+
end
51+
blocks, threads = thread_blocks_heuristic(total_threads)
52+
else
53+
if threads === nothing
54+
threads = 1
55+
end
56+
if blocks === nothing
57+
blocks = 1
58+
end
59+
end
60+
61+
gpu_call(backend(target), kernel, args...; threads=threads, blocks=blocks, kwargs...)
62+
end
63+
64+
gpu_call(backend::AbstractGPUBackend, kernel, args...; kwargs...) = error("Not implemented") # COV_EXCL_LINE
65+
66+
"""
67+
synchronize(A::AbstractArray)
68+
69+
Blocks until all operations are finished on `A`
70+
"""
71+
function synchronize(A::AbstractArray)
72+
# fallback is a noop, for backends not needing synchronization. This
73+
# makes it easier to write generic code that also works for AbstractArrays
74+
end
75+
76+
function thread_blocks_heuristic(len::Integer)
77+
# TODO better threads default
78+
threads = clamp(len, 1, 256)
79+
blocks = max(ceil(Int, len / threads), 1)
80+
(blocks, threads)
81+
end

src/device/indexing.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,63 +5,63 @@ export global_size, synchronize_threads, linear_index
55

66
# thread indexing functions
77
for f in (:blockidx, :blockdim, :threadidx, :griddim)
8-
@eval $f(state)::Int = error("Not implemented") # COV_EXCL_LINE
8+
@eval $f(ctx::AbstractKernelContext)::Int = error("Not implemented") # COV_EXCL_LINE
99
@eval export $f
1010
end
1111

1212
"""
13-
global_size(state)
13+
global_size(ctx::AbstractKernelContext)
1414
1515
Global size == blockdim * griddim == total number of kernel execution
1616
"""
17-
@inline function global_size(state)
18-
griddim(state) * blockdim(state)
17+
@inline function global_size(ctx::AbstractKernelContext)
18+
griddim(ctx) * blockdim(ctx)
1919
end
2020

2121
"""
22-
linear_index(state)
22+
linear_index(ctx::AbstractKernelContext)
2323
2424
linear index corresponding to each kernel launch (in OpenCL equal to get_global_id).
2525
2626
"""
27-
@inline function linear_index(state)
28-
(blockidx(state) - 1) * blockdim(state) + threadidx(state)
27+
@inline function linear_index(ctx::AbstractKernelContext)
28+
(blockidx(ctx) - 1) * blockdim(ctx) + threadidx(ctx)
2929
end
3030

3131
"""
32-
linearidx(A, statesym = :state)
32+
linearidx(A, ctxsym = :ctx)
3333
3434
Macro form of `linear_index`, which calls return when out of bounds.
3535
So it can be used like this:
3636
3737
```julia
38-
function kernel(state, A)
39-
idx = @linear_index A state
38+
function kernel(ctx::AbstractKernelContext, A)
39+
idx = @linear_index A ctx
4040
# from here on it's save to index into A with idx
4141
@inbounds begin
4242
A[idx] = ...
4343
end
4444
end
4545
```
4646
"""
47-
macro linearidx(A, statesym = :state)
47+
macro linearidx(A, ctxsym = :ctx)
4848
quote
4949
x1 = $(esc(A))
50-
i1 = linear_index($(esc(statesym)))
50+
i1 = linear_index($(esc(ctxsym)))
5151
i1 > length(x1) && return
5252
i1
5353
end
5454
end
5555

5656
"""
57-
cartesianidx(A, statesym = :state)
57+
cartesianidx(A, ctxsym = :ctx)
5858
59-
Like [`@linearidx(A, statesym = :state)`](@ref), but returns an N-dimensional `NTuple{ndim(A), Int}` as index
59+
Like [`@linearidx(A, ctxsym = :ctx)`](@ref), but returns an N-dimensional `NTuple{ndim(A), Int}` as index
6060
"""
61-
macro cartesianidx(A, statesym = :state)
61+
macro cartesianidx(A, ctxsym = :ctx)
6262
quote
6363
x = $(esc(A))
64-
i2 = @linearidx(x, $(esc(statesym)))
64+
i2 = @linearidx(x, $(esc(ctxsym)))
6565
gpu_ind2sub(x, i2)
6666
end
6767
end

src/device/memory.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# on-device memory management
2+
3+
export @LocalMemory
4+
5+
6+
## thread-local array
7+
8+
const shmem_counter = Ref{Int}(0)
9+
10+
"""
11+
Creates a local static memory shared inside one block.
12+
Equivalent to `__local` of OpenCL or `__shared__ (<variable>)` of CUDA.
13+
"""
14+
macro LocalMemory(ctx, T, N)
15+
id = (shmem_counter[] += 1)
16+
quote
17+
LocalMemory($(esc(ctx)), $(esc(T)), Val($(esc(N))), Val($id))
18+
end
19+
end
20+
21+
"""
22+
Creates a block local array pointer with `T` being the element type
23+
and `N` the length. Both T and N need to be static! C is a counter for
24+
approriately get the correct Local mem id in CUDAnative.
25+
This is an internal method which needs to be overloaded by the GPU Array backends
26+
"""
27+
function LocalMemory(ctx, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
28+
error("Not implemented") # COV_EXCL_LINE
29+
end

0 commit comments

Comments
 (0)