
Commit b841f66

Shuffle some more around, rename State to Context.
1 parent 11e6179 commit b841f66

17 files changed: 254 additions, 253 deletions

src/GPUArrays.jl

Lines changed: 8 additions & 6 deletions
@@ -12,17 +12,19 @@ using AbstractFFTs
 
 using Adapt
 
-# device array
+# device functionality
+include("device/device.jl")
+include("device/execution.jl")
+## on-device
 include("device/abstractarray.jl")
 include("device/indexing.jl")
+include("device/memory.jl")
 include("device/synchronization.jl")
 
-# host array
+# host array abstraction
 include("host/abstractarray.jl")
-include("host/devices.jl")
-include("host/execution.jl")
 include("host/construction.jl")
-## integrations and specialized functionality
+## integrations and specialized methods
 include("host/base.jl")
 include("host/indexing.jl")
 include("host/broadcast.jl")
@@ -32,7 +34,7 @@ include("host/random.jl")
 include("host/quirks.jl")
 
 # CPU reference implementation
-include("array.jl")
+include("reference.jl")
 
 
 end # module

src/device/abstractarray.jl

Lines changed: 2 additions & 28 deletions
@@ -1,6 +1,6 @@
-# on-device functionality
+# on-device array type
 
-export AbstractDeviceArray, @LocalMemory
+export AbstractDeviceArray
 
 
 ## device array
@@ -31,29 +31,3 @@ function Base.sum(A::AbstractDeviceArray{T}) where T
     end
     acc
 end
-
-
-## thread-local array
-
-const shmem_counter = Ref{Int}(0)
-
-"""
-Creates a local static memory shared inside one block.
-Equivalent to `__local` of OpenCL or `__shared__ (<variable>)` of CUDA.
-"""
-macro LocalMemory(state, T, N)
-    id = (shmem_counter[] += 1)
-    quote
-        LocalMemory($(esc(state)), $(esc(T)), Val($(esc(N))), Val($id))
-    end
-end
-
-"""
-Creates a block local array pointer with `T` being the element type
-and `N` the length. Both T and N need to be static! C is a counter for
-approriately get the correct Local mem id in CUDAnative.
-This is an internal method which needs to be overloaded by the GPU Array backends
-"""
-function LocalMemory(state, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
-    error("Not implemented") # COV_EXCL_LINE
-end

src/device/device.jl

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# device management and properties
+
+export AbstractGPUDevice
+
+abstract type AbstractGPUDevice end
+
+"""
+Hardware threads of device
+"""
+threads(::AbstractGPUDevice) = error("Not implemented") # COV_EXCL_LINE
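
The new `AbstractGPUDevice` supertype gives backends a common hook for exposing device properties such as `threads`. A minimal sketch of how a backend might opt in, assuming a made-up `FakeDevice` type and thread count (neither is part of this commit):

```julia
using GPUArrays

# Hypothetical backend device; the type name and the constant are placeholders.
struct FakeDevice <: AbstractGPUDevice end

# Report how many hardware threads this (imaginary) device offers.
GPUArrays.threads(::FakeDevice) = 256
```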

src/host/execution.jl renamed to src/device/execution.jl

Lines changed: 7 additions & 5 deletions
@@ -1,9 +1,11 @@
 # kernel execution
 
-export AbstractGPUBackend, gpu_call, synchronize, thread_blocks_heuristic
+export AbstractGPUBackend, AbstractKernelContext, gpu_call, synchronize, thread_blocks_heuristic
 
 abstract type AbstractGPUBackend end
 
+abstract type AbstractKernelContext end
+
 backend(::Type{T}) where T = error("Can't choose GPU backend for $T")
 
 """
@@ -12,12 +14,12 @@ backend(::Type{T}) where T = error("Can't choose GPU backend for $T")
 Calls function `kernel` on the GPU.
 `A` must be an AbstractGPUArray and will help to dispatch to the correct GPU backend
 and supplies queues and contexts.
-Calls the kernel function with `kernel(state, args...)`, where state is dependant on the backend
-and can be used for getting an index into `A` with `linear_index(state)`.
+Calls the kernel function with `kernel(ctx, args...)`, where ctx is dependant on the backend
+and can be used for getting an index into `A` with `linear_index(ctx)`.
 Optionally, a launch configuration can be supplied in the following way:
 
 1) A single integer, indicating how many work items (total number of threads) you want to launch.
-   in this case `linear_index(state)` will be a number in the range `1:configuration`
+   in this case `linear_index(ctx)` will be a number in the range `1:configuration`
 2) Pass a tuple of integer tuples to define blocks and threads per blocks!
 
 """
@@ -38,7 +40,7 @@ function gpu_call(kernel, A::AbstractArray, args::Tuple, configuration = length(
     Found: $configurations
     Configuration needs to be:
     1) A single integer, indicating how many work items (total number of threads) you want to launch.
-       in this case `linear_index(state)` will be a number in the range 1:configuration
+       in this case `linear_index(ctx)` will be a number in the range 1:configuration
     2) Pass a tuple of integer tuples to define blocks and threads per blocks!
     `linear_index` will be inbetween 1:prod((blocks..., threads...))
     """)

src/device/indexing.jl

Lines changed: 16 additions & 16 deletions
@@ -5,63 +5,63 @@ export global_size, synchronize_threads, linear_index
 
 # thread indexing functions
 for f in (:blockidx, :blockdim, :threadidx, :griddim)
-    @eval $f(state)::Int = error("Not implemented") # COV_EXCL_LINE
+    @eval $f(ctx::AbstractKernelContext)::Int = error("Not implemented") # COV_EXCL_LINE
     @eval export $f
 end
 
 """
-    global_size(state)
+    global_size(ctx::AbstractKernelContext)
 
 Global size == blockdim * griddim == total number of kernel execution
 """
-@inline function global_size(state)
-    griddim(state) * blockdim(state)
+@inline function global_size(ctx::AbstractKernelContext)
+    griddim(ctx) * blockdim(ctx)
 end
 
 """
-    linear_index(state)
+    linear_index(ctx::AbstractKernelContext)
 
 linear index corresponding to each kernel launch (in OpenCL equal to get_global_id).
 
 """
-@inline function linear_index(state)
-    (blockidx(state) - 1) * blockdim(state) + threadidx(state)
+@inline function linear_index(ctx::AbstractKernelContext)
+    (blockidx(ctx) - 1) * blockdim(ctx) + threadidx(ctx)
 end
 
 """
-    linearidx(A, statesym = :state)
+    linearidx(A, ctxsym = :ctx)
 
 Macro form of `linear_index`, which calls return when out of bounds.
 So it can be used like this:
 
 ```julia
-function kernel(state, A)
-    idx = @linear_index A state
+function kernel(ctx::AbstractKernelContext, A)
+    idx = @linear_index A ctx
     # from here on it's save to index into A with idx
     @inbounds begin
         A[idx] = ...
    end
 end
 ```
 """
-macro linearidx(A, statesym = :state)
+macro linearidx(A, ctxsym = :ctx)
     quote
         x1 = $(esc(A))
-        i1 = linear_index($(esc(statesym)))
+        i1 = linear_index($(esc(ctxsym)))
         i1 > length(x1) && return
         i1
     end
 end
 
 """
-    cartesianidx(A, statesym = :state)
+    cartesianidx(A, ctxsym = :ctx)
 
-Like [`@linearidx(A, statesym = :state)`](@ref), but returns an N-dimensional `NTuple{ndim(A), Int}` as index
+Like [`@linearidx(A, ctxsym = :ctx)`](@ref), but returns an N-dimensional `NTuple{ndim(A), Int}` as index
 """
-macro cartesianidx(A, statesym = :state)
+macro cartesianidx(A, ctxsym = :ctx)
     quote
         x = $(esc(A))
-        i2 = @linearidx(x, $(esc(statesym)))
+        i2 = @linearidx(x, $(esc(ctxsym)))
         gpu_ind2sub(x, i2)
     end
 end
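
A hypothetical sketch (not part of the commit) of how a backend's context type could implement the indexing primitives declared as stubs above; `FakeCtx` and the hard-coded dimensions are placeholders:

```julia
using GPUArrays

# Imaginary kernel context carrying a 1-based thread/block position.
struct FakeCtx <: AbstractKernelContext
    thread::Int
    block::Int
end

GPUArrays.threadidx(ctx::FakeCtx) = ctx.thread
GPUArrays.blockidx(ctx::FakeCtx)  = ctx.block
GPUArrays.blockdim(::FakeCtx)     = 256   # threads per block (made up)
GPUArrays.griddim(::FakeCtx)      = 4     # blocks per launch (made up)

# With these definitions:
#   linear_index(FakeCtx(1, 2)) == (2 - 1) * 256 + 1 == 257
#   global_size(FakeCtx(1, 2))  == 4 * 256 == 1024
```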

src/device/memory.jl

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# on-device memory management
+
+export @LocalMemory
+
+
+## thread-local array
+
+const shmem_counter = Ref{Int}(0)
+
+"""
+Creates a local static memory shared inside one block.
+Equivalent to `__local` of OpenCL or `__shared__ (<variable>)` of CUDA.
+"""
+macro LocalMemory(ctx, T, N)
+    id = (shmem_counter[] += 1)
+    quote
+        LocalMemory($(esc(ctx)), $(esc(T)), Val($(esc(N))), Val($id))
+    end
+end
+
+"""
+Creates a block local array pointer with `T` being the element type
+and `N` the length. Both T and N need to be static! C is a counter for
+approriately get the correct Local mem id in CUDAnative.
+This is an internal method which needs to be overloaded by the GPU Array backends
+"""
+function LocalMemory(ctx, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
+    error("Not implemented") # COV_EXCL_LINE
+end
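
As a usage sketch (not part of this change), a kernel could request a block-local scratch buffer via `@LocalMemory` and coordinate it with `synchronize_threads`; this assumes a backend that implements `LocalMemory` as an indexable device array, and the element type and length must be compile-time constants:

```julia
using GPUArrays

# Hedged sketch: reverse each 256-element block of A in place via local memory.
function reverse_block_kernel!(ctx, A)
    tmp = @LocalMemory(ctx, Float32, 256)   # static buffer shared within one block
    t = threadidx(ctx)
    @inbounds tmp[t] = A[linear_index(ctx)]
    synchronize_threads(ctx)                 # wait until every thread has written
    @inbounds A[linear_index(ctx)] = tmp[256 - t + 1]
    return
end
```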

src/device/synchronization.jl

Lines changed: 2 additions & 2 deletions
@@ -3,11 +3,11 @@
 export synchronize_threads
 
 """
-    synchronize_threads(state)
+    synchronize_threads(ctx::AbstractKernelContext)
 
 in CUDA terms `__synchronize`
 in OpenCL terms: `barrier(CLK_LOCAL_MEM_FENCE)`
 """
-function synchronize_threads(state)
+function synchronize_threads(ctx::AbstractKernelContext)
     error("Not implemented") # COV_EXCL_LINE
 end

src/host/abstractarray.jl

Lines changed: 9 additions & 2 deletions
@@ -15,6 +15,13 @@ const AbstractGPUVector{T} = AbstractGPUArray{T, 1}
 const AbstractGPUMatrix{T} = AbstractGPUArray{T, 2}
 const AbstractGPUVecOrMat{T} = Union{AbstractGPUArray{T, 1}, AbstractGPUArray{T, 2}}
 
+"""
+    device(A::AbstractArray)
+
+Gets the device associated to the Array `A`
+"""
+device(A::AbstractArray) = error("Not implemented") # COV_EXCL_LINE
+
 
 # input/output
 
@@ -136,8 +143,8 @@ end
 Base.copyto!(dest::AbstractGPUArray, src::AbstractGPUArray) =
     copyto!(dest, CartesianIndices(dest), src, CartesianIndices(src))
 
-function copy_kernel!(state, dest, dest_offsets, src, src_offsets, shape, shape_dest, shape_source, length)
-    i = linear_index(state)
+function copy_kernel!(ctx::AbstractKernelContext, dest, dest_offsets, src, src_offsets, shape, shape_dest, shape_source, length)
+    i = linear_index(ctx)
     if i <= length
         # TODO can this be done faster and smarter?
         idx = gpu_ind2sub(shape, i)
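
A hypothetical sketch (not from the commit) of how a backend array type would answer the new `device` query, reusing the placeholder `FakeDevice` from the device.jl example above:

```julia
using GPUArrays

# Placeholder array type; a real backend would wrap GPU memory here.
struct FakeArray{T, N} <: AbstractGPUArray{T, N}
    data::Array{T, N}
end

# Every FakeArray lives on the imaginary device defined earlier.
GPUArrays.device(::FakeArray) = FakeDevice()
```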

src/host/base.jl

Lines changed: 4 additions & 4 deletions
@@ -62,8 +62,8 @@ end
 function Base.repeat(a::AbstractGPUVecOrMat, m::Int, n::Int = 1)
     o, p = size(a, 1), size(a, 2)
     b = similar(a, o*m, p*n)
-    gpu_call(a, (b, a, o, p, m, n), n) do state, b, a, o, p, m, n
-        j = linear_index(state)
+    gpu_call(a, (b, a, o, p, m, n), n) do ctx, b, a, o, p, m, n
+        j = linear_index(ctx)
         j > n && return
         d = (j - 1) * p + 1
         @inbounds for i in 1:m
@@ -82,8 +82,8 @@ end
 function Base.repeat(a::AbstractGPUVector, m::Int)
     o = length(a)
     b = similar(a, o*m)
-    gpu_call(a, (b, a, o, m), m) do state, b, a, o, m
-        i = linear_index(state)
+    gpu_call(a, (b, a, o, m), m) do ctx, b, a, o, m
+        i = linear_index(ctx)
         i > m && return
         c = (i - 1)*o + 1
         @inbounds for i in 1:o

src/host/broadcast.jl

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ end
 @inline function Base.copyto!(dest::GPUDestArray, bc::Broadcasted{Nothing})
     axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
     bc′ = Broadcast.preprocess(dest, bc)
-    gpu_call(dest, (dest, bc′)) do state, dest, bc′
+    gpu_call(dest, (dest, bc′)) do ctx, dest, bc′
         let I = CartesianIndex(@cartesianidx(dest))
             @inbounds dest[I] = bc′[I]
         end
