
Commit f641b2f

Reshuffle and redocument some code.
1 parent 0868d9f commit f641b2f

4 files changed: +119 −115 lines changed


src/abstract_gpu_interface.jl

Lines changed: 21 additions & 108 deletions
@@ -1,111 +1,3 @@
-#=
-Abstraction over the GPU thread indexing functions.
-Uses CUDA like names
-=#
-for sym in (:x, :y, :z)
-    for f in (:blockidx, :blockdim, :threadidx, :griddim)
-        fname = Symbol(string(f, '_', sym))
-        @eval $fname(state)::Int = error("Not implemented")
-        @eval export $fname
-    end
-end
-
-
-
-"""
-    synchronize_threads(state)
-
-in CUDA terms `__synchronize`
-in OpenCL terms: `barrier(CLK_LOCAL_MEM_FENCE)`
-"""
-function synchronize_threads(state)
-    error("Not implemented")
-end
-
-
-"""
-    linear_index(state)
-
-linear index corresponding to each kernel launch (in OpenCL equal to get_global_id).
-
-"""
-@inline function linear_index(state)
-    (blockidx_x(state) - 1) * blockdim_x(state) + threadidx_x(state)
-end
-
-"""
-    linearidx(A, statesym = :state)
-
-Macro form of `linear_index`, which calls return when out of bounds.
-So it can be used like this:
-
-```julia
-function kernel(state, A)
-    idx = @linear_index A state
-    # from here on it's save to index into A with idx
-    @inbounds begin
-        A[idx] = ...
-    end
-end
-```
-"""
-macro linearidx(A, statesym = :state)
-    quote
-        x1 = $(esc(A))
-        i1 = linear_index($(esc(statesym)))
-        i1 > length(x1) && return
-        i1
-    end
-end
-
-
-"""
-    cartesianidx(A, statesym = :state)
-
-Like [`@linearidx(A, statesym = :state)`](@ref), but returns an N-dimensional `NTuple{ndim(A), Int}` as index
-"""
-macro cartesianidx(A, statesym = :state)
-    quote
-        x = $(esc(A))
-        i2 = @linearidx(x, $(esc(statesym)))
-        gpu_ind2sub(x, i2)
-    end
-end
-
-"""
-    global_size(state)
-
-Global size == blockdim * griddim == total number of kernel execution
-"""
-@inline function global_size(state)
-    # TODO nd version
-    griddim_x(state) * blockdim_x(state)
-end
-
-"""
-    device(A::AbstractArray)
-
-Gets the device associated to the Array `A`
-"""
-function device(A::AbstractArray)
-    # fallback is a noop, for backends not needing synchronization. This
-    # makes it easier to write generic code that also works for AbstractArrays
-end
-
-"""
-    synchronize(A::AbstractArray)
-
-Blocks until all operations are finished on `A`
-"""
-function synchronize(A::AbstractArray)
-    # fallback is a noop, for backends not needing synchronization. This
-    # makes it easier to write generic code that also works for AbstractArrays
-end
-#
-# @inline function synchronize_threads(state)
-#     CUDAnative.__syncthreads()
-# end
-
 abstract type GPUBackend end
 backend(::Type{T}) where T = error("Can't choose GPU backend for $T")

@@ -153,3 +45,24 @@ end
 
 # Internal GPU call function that needs to be overloaded by the backends.
 _gpu_call(::Any, f, A, args, thread_blocks) = error("Not implemented")
+
+
+"""
+    device(A::AbstractArray)
+
+Gets the device associated with the array `A`.
+"""
+function device(A::AbstractArray)
+    # fallback is a noop, for backends not needing synchronization. This
+    # makes it easier to write generic code that also works for AbstractArrays
+end
+
+"""
+    synchronize(A::AbstractArray)
+
+Blocks until all operations on `A` have finished.
+"""
+function synchronize(A::AbstractArray)
+    # fallback is a noop, for backends not needing synchronization. This
+    # makes it easier to write generic code that also works for AbstractArrays
+end
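
To make the backend interface above concrete, here is a minimal hypothetical sketch of how a backend could plug into it. `ThreadedBackend` and the NamedTuple `state` are illustrative assumptions, not types defined by this commit; a real backend would also provide `blockidx_x`-style methods for its state type.

```julia
# Hypothetical sketch: a toy CPU "backend" satisfying the interface.
struct ThreadedBackend <: GPUBackend end

# choose this backend for plain Arrays (illustrative only)
backend(::Type{<:Array}) = ThreadedBackend()

function _gpu_call(::ThreadedBackend, f, A, args, thread_blocks)
    blocks, threads = thread_blocks
    for b in 1:blocks, t in 1:threads
        # `state` carries the indexing information that blockidx_x,
        # blockdim_x, threadidx_x, griddim_x would read on a real device
        state = (blockidx = b, blockdim = threads, threadidx = t, griddim = blocks)
        f(state, args...)
    end
end
```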

src/abstractarray.jl

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
-# Dense GPU Array
+# core definition of the GPUArray type
+
 abstract type GPUArray{T, N} <: DenseArray{T, N} end
 
 # Sampler type that acts like a texture/image and allows interpolated access
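
For orientation, a backend package subtypes `GPUArray` with its own concrete array type. The following is a hedged sketch only; `MyBackendArray` and its host-memory `data` field are assumptions, and a usable subtype would implement much more of the `AbstractArray` interface.

```julia
# Hypothetical sketch of a backend array type.
struct MyBackendArray{T, N} <: GPUArray{T, N}
    data::Array{T, N}   # a real backend would hold a device buffer here
end

Base.size(A::MyBackendArray) = size(A.data)
```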

src/blas.jl

Lines changed: 7 additions & 6 deletions
@@ -1,14 +1,18 @@
 
-# Interface that needs to be overwritten by backend
-# Slightly difference behavior from buffer, since not all blas backends work directly with
-# the gpu array buffer
+# calls to standard BLAS interfaces
+
+## interface
+
 function blas_module(A)
     error("$(typeof(A)) doesn't support BLAS operations")
 end
 function blasbuffer(A)
     error("$(typeof(A)) doesn't support BLAS operations")
 end
 
+
+## operations
+
 for elty in (Float32, Float64, ComplexF32, ComplexF64)
     T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty
     @eval begin
@@ -53,7 +57,6 @@ function LinearAlgebra.rmul!(X::GPUArray{T}, s::Number) where T <: Union{Float32
     X
 end
 
-
 for elty in (Float32, Float64, ComplexF32, ComplexF64)
     T = VERSION >= v"1.3.0-alpha.115" ? :(Union{($elty), Bool}) : elty
     @eval begin
@@ -76,7 +79,6 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64)
     end
 end
 
-
 for elty in (Float32, Float64, ComplexF32, ComplexF64)
     @eval begin
         function BLAS.axpy!(
@@ -92,7 +94,6 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64)
     end
 end
 
-
 for elty in (Float32, Float64, ComplexF32, ComplexF64)
     @eval begin
         function BLAS.gbmv!(trans::AbstractChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::GPUMatrix{$elty}, X::GPUVector{$elty}, beta::($elty), Y::GPUVector{$elty})
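
As a rough illustration of the `## interface` functions: a backend overrides `blas_module` to return the module whose BLAS routines should be called, and `blasbuffer` to return an object those routines can operate on. The `Array` methods below are assumptions for the sketch, not methods this commit defines.

```julia
using LinearAlgebra

# Illustrative sketch only: how a backend could satisfy the BLAS interface.
blas_module(A::Array) = LinearAlgebra.BLAS  # module providing gemm!/axpy!/...
blasbuffer(A::Array) = A                    # object those BLAS routines accept
```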

src/ondevice.jl

Lines changed: 89 additions & 0 deletions
@@ -1,3 +1,90 @@
+# functionality for vendor-agnostic kernels
+
+## indexing
+
+# thread indexing functions
+for sym in (:x, :y, :z)
+    for f in (:blockidx, :blockdim, :threadidx, :griddim)
+        fname = Symbol(string(f, '_', sym))
+        @eval $fname(state)::Int = error("Not implemented")
+        @eval export $fname
+    end
+end
+
+"""
+    global_size(state)
+
+Global size == blockdim * griddim == total number of kernel executions.
+"""
+@inline function global_size(state)
+    # TODO nd version
+    griddim_x(state) * blockdim_x(state)
+end
+
+"""
+    linear_index(state)
+
+Linear index of the current thread within the kernel launch (in OpenCL terms, `get_global_id`).
+
+"""
+@inline function linear_index(state)
+    (blockidx_x(state) - 1) * blockdim_x(state) + threadidx_x(state)
+end
+
+"""
+    linearidx(A, statesym = :state)
+
+Macro form of `linear_index`, which returns from the kernel when the index is out of bounds.
+It can be used like this:
+
+```julia
+function kernel(state, A)
+    idx = @linearidx A state
+    # from here on it's safe to index into A with idx
+    @inbounds begin
+        A[idx] = ...
+    end
+end
+```
+"""
+macro linearidx(A, statesym = :state)
+    quote
+        x1 = $(esc(A))
+        i1 = linear_index($(esc(statesym)))
+        i1 > length(x1) && return
+        i1
+    end
+end
+
+"""
+    cartesianidx(A, statesym = :state)
+
+Like [`@linearidx(A, statesym = :state)`](@ref), but returns an N-dimensional `NTuple{ndims(A), Int}` as the index.
+"""
+macro cartesianidx(A, statesym = :state)
+    quote
+        x = $(esc(A))
+        i2 = @linearidx(x, $(esc(statesym)))
+        gpu_ind2sub(x, i2)
+    end
+end
+
+
+## synchronization
+
+"""
+    synchronize_threads(state)
+
+in CUDA terms: `__syncthreads`
+in OpenCL terms: `barrier(CLK_LOCAL_MEM_FENCE)`
+"""
+function synchronize_threads(state)
+    error("Not implemented")
+end
+
+
+## device array
+
 abstract type AbstractDeviceArray{T, N} <: AbstractArray{T, N} end
 
 Base.IndexStyle(::AbstractDeviceArray) = IndexLinear()
@@ -19,6 +106,8 @@ function Base.sum(A::AbstractDeviceArray{T}) where T
 end
 
 
+## device memory
+
 const shmem_counter = Ref{Int}(0)
 
 """
