
Commit 56d0c1f

Simplify gpu_call.

1 parent 7e330b9

12 files changed: 70 additions & 73 deletions

src/device/execution.jl

Lines changed: 32 additions & 34 deletions
@@ -14,47 +14,45 @@ Gets the GPUArrays back-end responsible for managing arrays of type `T`.
 backend(::Type{<:AbstractArray}) = error("This array is not a GPU array") # COV_EXCL_LINE
 
 """
-    gpu_call(kernel::Function, A::AbstractGPUArray, args::Tuple, configuration = length(A))
+    gpu_call(kernel::Function, A::AbstractGPUArray, args...; kwargs...)
 
-Calls function `kernel` on the GPU.
-`A` must be an AbstractGPUArray and will help to dispatch to the correct GPU backend
-and supplies queues and contexts.
-Calls the kernel function with `kernel(ctx, args...)`, where ctx is dependant on the backend
-and can be used for getting an index into `A` with `linear_index(ctx)`.
-Optionally, a launch configuration can be supplied in the following way:
-
-1) A single integer, indicating how many work items (total number of threads) you want to launch.
-   in this case `linear_index(ctx)` will be a number in the range `1:configuration`
-2) Pass a tuple of integer tuples to define blocks and threads per blocks!
+Calls function `kernel` on the GPU device that backs array `A`, passing along arguments
+`args`. The keyword arguments `kwargs` are not passed along, but are interpreted on the host
+to influence how the kernel is executed. The following keyword arguments are supported:
 
+- `total_threads::Int`: how many threads should be launched _in total_. The actual number of
+  threads and blocks is determined using a heuristic. Defaults to the length of `A` if no
+  other keyword arguments that influence the launch configuration are specified.
+- `threads::Int` and `blocks::Int`: configure exactly how many threads and blocks are
+  launched. This cannot be used in combination with the `total_threads` argument.
 """
-function gpu_call(kernel, A::AbstractArray, args::Tuple, configuration = length(A))
-    ITuple = NTuple{N, Integer} where N
-    # If is a single integer, we assume it to be the global size / total number of threads one wants to launch
-    thread_blocks = if isa(configuration, Integer)
-        thread_blocks_heuristic(configuration)
-    elseif isa(configuration, ITuple)
-        @assert length(configuration) == 1
-        configuration[1], 1
-    elseif isa(configuration, Tuple{ITuple, ITuple})
-        @assert length(configuration[1]) == 1
-        @assert length(configuration[2]) == 1
-        configuration[1][1], configuration[2][1]
+function gpu_call(kernel::Base.Callable, A::AbstractArray, args...;
+                  total_threads::Union{Int,Nothing}=nothing,
+                  threads::Union{Int,Nothing}=nothing,
+                  blocks::Union{Int,Nothing}=nothing,
+                  kwargs...)
+    # determine how many threads/blocks to launch
+    if total_threads===nothing && threads===nothing && blocks===nothing
+        total_threads = length(A)
+    end
+    if total_threads !== nothing
+        if threads !== nothing || blocks !== nothing
+            error("Cannot specify both total_threads and threads/blocks configuration")
+        end
+        threads, blocks = thread_blocks_heuristic(total_threads)
     else
-        error("""Please launch a gpu kernel with a valid configuration.
-              Found: $configurations
-              Configuration needs to be:
-              1) A single integer, indicating how many work items (total number of threads) you want to launch.
-                 in this case `linear_index(ctx)` will be a number in the range 1:configuration
-              2) Pass a tuple of integer tuples to define blocks and threads per blocks!
-                 `linear_index` will be inbetween 1:prod((blocks..., threads...))
-              """)
+        if threads === nothing
+            threads = 1
+        end
+        if blocks === nothing
+            blocks = 1
+        end
     end
-    _gpu_call(backend(typeof(A)), kernel, A, args, thread_blocks)
+
+    gpu_call(backend(typeof(A)), kernel, args...; threads=threads, blocks=blocks, kwargs...)
 end
 
-# Internal GPU call function, that needs to be overloaded by the backends.
-_gpu_call(::Any, f, A, args, thread_blocks) = error("Not implemented") # COV_EXCL_LINE
+gpu_call(backend::AbstractGPUBackend, kernel, args...; kwargs...) = error("Not implemented") # COV_EXCL_LINE
 
 """
     synchronize(A::AbstractArray)
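To illustrate the new interface, here is a minimal sketch of the three launch configurations the docstring describes. The kernel name and array are invented for the example; it assumes `gpu_call` and `linear_index` are available from the package and uses the qualified `GPUArrays.JLArray` reference array type (defined in src/reference.jl below):

```julia
using GPUArrays

# Hypothetical kernel: the first argument is the backend-supplied context.
function add_one_kernel(ctx, a)
    i = linear_index(ctx)       # 1:total_threads
    i > length(a) && return
    @inbounds a[i] += 1
    return
end

A = GPUArrays.JLArray(zeros(Float32, 1024))

gpu_call(add_one_kernel, A, A)                             # defaults to total_threads = length(A)
gpu_call(add_one_kernel, A, A; total_threads = 512)        # heuristic splits into threads/blocks
gpu_call(add_one_kernel, A, A; threads = 256, blocks = 4)  # fully explicit configuration
```

Passing `total_threads` together with `threads`/`blocks` raises the error shown in the diff above, instead of silently picking one configuration.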

src/host/abstractarray.jl

Lines changed: 2 additions & 2 deletions
@@ -162,8 +162,8 @@ function Base.copyto!(dest::AbstractGPUArray{T, N}, destcrange::CartesianIndices
     dest_offsets = first.(destcrange.indices) .- 1
     src_offsets = first.(srccrange.indices) .- 1
     gpu_call(copy_kernel!, dest,
-             (dest, dest_offsets, src, src_offsets, shape, size(dest), size(src), len),
-             len)
+             dest, dest_offsets, src, src_offsets, shape, size(dest), size(src), len;
+             total_threads=len)
     dest
 end
 
src/host/base.jl

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ end
 function Base.repeat(a::AbstractGPUVecOrMat, m::Int, n::Int = 1)
     o, p = size(a, 1), size(a, 2)
     b = similar(a, o*m, p*n)
-    gpu_call(a, (b, a, o, p, m, n), n) do ctx, b, a, o, p, m, n
+    gpu_call(a, b, a, o, p, m, n; total_threads=n) do ctx, b, a, o, p, m, n
         j = linear_index(ctx)
         j > n && return
         d = (j - 1) * p + 1
@@ -82,7 +82,7 @@ end
 function Base.repeat(a::AbstractGPUVector, m::Int)
     o = length(a)
     b = similar(a, o*m)
-    gpu_call(a, (b, a, o, m), m) do ctx, b, a, o, m
+    gpu_call(a, b, a, o, m; total_threads=m) do ctx, b, a, o, m
         i = linear_index(ctx)
         i > m && return
         c = (i - 1)*o + 1
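The `repeat` call sites above also show how anonymous kernels fit the new signature: the do-block becomes the `kernel` argument, and all remaining positional arguments are forwarded to it after the context. A hedged sketch (the array and the doubling kernel are made up for illustration):

```julia
a = GPUArrays.JLArray(ones(Float32, 16))

gpu_call(a, a; total_threads = length(a)) do ctx, x
    i = linear_index(ctx)
    i > length(x) && return
    @inbounds x[i] *= 2   # kernel body runs once per launched thread
    return
end
```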

src/host/broadcast.jl

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ end
 @inline function Base.copyto!(dest::GPUDestArray, bc::Broadcasted{Nothing})
     axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
     bc′ = Broadcast.preprocess(dest, bc)
-    gpu_call(dest, (dest, bc′)) do ctx, dest, bc′
+    gpu_call(dest, dest, bc′) do ctx, dest, bc′
         let I = CartesianIndex(@cartesianidx(dest))
             @inbounds dest[I] = bc′[I]
         end

src/host/construction.jl

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@ function Base.fill(X::Type{<: AbstractGPUArray{T}}, val, dims::NTuple{N, Integer
     fill!(res, convert(T, val))
 end
 function Base.fill!(A::AbstractGPUArray{T}, x) where T
-    gpu_call(A, (A, convert(T, x))) do ctx, a, val
+    gpu_call(A, A, convert(T, x)) do ctx, a, val
         idx = @linearidx(a, ctx)
         @inbounds a[idx] = val
         return
@@ -30,7 +30,7 @@ end
 
 function (T::Type{<: AbstractGPUArray})(s::UniformScaling, dims::Dims{2})
     res = zeros(T, dims)
-    gpu_call(uniformscaling_kernel, res, (res, size(res, 1), s), minimum(dims))
+    gpu_call(uniformscaling_kernel, res, res, size(res, 1), s; total_threads=minimum(dims))
     res
 end
 (T::Type{<: AbstractGPUArray})(s::UniformScaling, m::Integer, n::Integer) = T(s, Dims((m, n)))
@@ -67,7 +67,7 @@ function Base.convert(AT::Type{<: AbstractGPUArray}, iter)
     if isbits(iter) && isa(isize, Base.HasShape) && style != nothing && isa(ettrait, Base.HasEltype)
         # We can collect on the GPU
         A = similar(AT, eltype_or(AT, eltype(iter)), size(iter))
-        gpu_call(collect_kernel, A, (A, iter, style))
+        gpu_call(collect_kernel, A, A, iter, style)
         A
     else
         convert(AT, collect(iter))

src/host/indexing.jl

Lines changed: 3 additions & 2 deletions
@@ -98,7 +98,7 @@ function Base._unsafe_getindex!(dest::AbstractGPUArray, src::AbstractGPUArray, I
         return dest
     end
     idims = map(length, Is)
-    gpu_call(index_kernel, dest, (dest, src, idims, map(x-> to_index(dest, x), Is)))
+    gpu_call(index_kernel, dest, dest, src, idims, map(x-> to_index(dest, x), Is))
     return dest
 end
 
@@ -125,6 +125,7 @@ function Base._unsafe_setindex!(::IndexStyle, dest::T, src, Is::Union{Real, Abst
     idims = length.(Is)
     len = prod(idims)
     src_gpu = adapt(T, src)
-    gpu_call(setindex_kernel!, dest, (dest, src_gpu, idims, map(x-> to_index(dest, x), Is), len), len)
+    gpu_call(setindex_kernel!, dest, dest, src_gpu, idims, map(x-> to_index(dest, x), Is), len;
+             total_threads=len)
     return dest
 end

src/host/linalg.jl

Lines changed: 2 additions & 2 deletions
@@ -115,7 +115,7 @@ end
 ## high-level functionality
 
 function LinearAlgebra.transpose!(At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
-    gpu_call(At, (At, A)) do ctx, At, A
+    gpu_call(At, At, A) do ctx, At, A
         idx = @cartesianidx A ctx
         @inbounds At[idx[2], idx[1]] = A[idx[1], idx[2]]
         return
@@ -129,7 +129,7 @@ end
 
 function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) where N
     perm isa Tuple || (perm = Tuple(perm))
-    gpu_call(dest, (dest, src, perm)) do ctx, dest, src, perm
+    gpu_call(dest, dest, src, perm) do ctx, dest, src, perm
         I = @cartesianidx src ctx
         @inbounds dest[genperm(I, perm)...] = src[I...]
         return

src/host/mapreduce.jl

Lines changed: 12 additions & 12 deletions
@@ -9,7 +9,7 @@ Base.count(pred::Function, A::AbstractGPUArray) = Int(mapreduce(pred, +, A; init
 
 Base.:(==)(A::AbstractGPUArray, B::AbstractGPUArray) = Bool(mapreduce(==, &, A, B; init = true))
 
-LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = acc_mapreduce(==, &, true, A, (adjoint(A),))
+LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = acc_mapreduce(==, &, true, A, adjoint(A))
 
 # hack to get around of fetching the first element of the AbstractGPUArray
 # as a startvalue, which is a bit complicated with the current reduce implementation
@@ -67,11 +67,11 @@ end
 function mapreduce_impl(f, op, ::NamedTuple{()}, A::GPUSrcArray, ::Colon)
     OT = gpu_promote_type(op, gpu_promote_type(f, eltype(A)))
     v0 = startvalue(op, OT) # TODO do this better
-    acc_mapreduce(f, op, v0, A, ())
+    acc_mapreduce(f, op, v0, A)
 end
 
 function mapreduce_impl(f, op, nt::NamedTuple{(:init,)}, A::GPUSrcArray, ::Colon)
-    acc_mapreduce(f, op, nt.init, A, ())
+    acc_mapreduce(f, op, nt.init, A)
 end
 
 function mapreduce_impl(f, op, nt, A::GPUSrcArray, dims)
@@ -80,10 +80,10 @@ end
 
 function acc_mapreduce end
 function Base.mapreduce(f, op, A::GPUSrcArray, B::GPUSrcArray, C::Number; init)
-    acc_mapreduce(f, op, init, A, (B, C))
+    acc_mapreduce(f, op, init, A, B, C)
 end
 function Base.mapreduce(f, op, A::GPUSrcArray, B::GPUSrcArray; init)
-    acc_mapreduce(f, op, init, A, (B,))
+    acc_mapreduce(f, op, init, A, B)
 end
 
 @generated function mapreducedim_kernel(ctx::AbstractKernelContext, f, op, R, A, range::NTuple{N, Any}) where N
@@ -118,7 +118,7 @@ end
 
 function Base._mapreducedim!(f, op, R::AbstractGPUArray, A::GPUSrcArray)
     range = ifelse.(length.(axes(R)) .== 1, axes(A), nothing)
-    gpu_call(mapreducedim_kernel, R, (f, op, R, A, range))
+    gpu_call(mapreducedim_kernel, R, f, op, R, A, range)
     return R
 end
 
@@ -165,17 +165,17 @@ for i = 0:10
 
 end
 
-function acc_mapreduce(f, op, v0::OT, A::GPUSrcArray, rest::Tuple) where {OT}
-    blocksize = 80
+function acc_mapreduce(f, op, v0::OT, A::GPUSrcArray, rest...) where {OT}
+    blocks = 80
     threads = 256
-    if length(A) <= blocksize * threads
+    if length(A) <= blocks * threads
         args = zip(convert_to_cpu(A), convert_to_cpu.(rest)...)
         return mapreduce(x-> f(x...), op, args, init = v0)
     end
-    out = similar(A, OT, (blocksize,))
+    out = similar(A, OT, (blocks,))
     fill!(out, v0)
-    args = (f, op, v0, A, Val{threads}(), out, rest...)
-    gpu_call(reduce_kernel, out, args, ((blocksize,), (threads,)))
+    gpu_call(reduce_kernel, out, f, op, v0, A, Val{threads}(), out, rest...;
+             threads=threads, blocks=blocks)
     reduce(op, Array(out))
 end
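The `acc_mapreduce` change (turning the `rest::Tuple` parameter into a `rest...` vararg) is internal; user-facing reductions are unchanged. A small sketch of calls that route through the code above, again assuming the `JLArray` reference back-end:

```julia
A = GPUArrays.JLArray(rand(Float32, 10_000))
B = GPUArrays.JLArray(rand(Float32, 10_000))

s = mapreduce(+, +, A, B; init = 0f0)  # elementwise f(aᵢ, bᵢ), combined with op = +
same = A == B                          # also lowers to a mapreduce with op = &
```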

src/host/random.jl

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ function global_rng(A::AbstractGPUArray)
 end
 
 function Random.rand!(rng::RNG, A::AbstractGPUArray{T}) where T <: Number
-    gpu_call(A, (rng.state, A,)) do ctx, randstates, a
+    gpu_call(A, rng.state, A) do ctx, randstates, a
         idx = linear_index(ctx)
         idx > length(a) && return
         @inbounds a[idx] = gpu_rand(T, ctx, randstates)

src/reference.jl

Lines changed: 1 addition & 2 deletions
@@ -52,8 +52,7 @@ end
 to_device(ctx, x::Tuple) = to_device.(Ref(ctx), x)
 to_device(ctx, x) = x
 
-function GPUArrays._gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{Int, Int})
-    blocks, threads = blocks_threads
+function GPUArrays.gpu_call(::JLBackend, f, args...; blocks::Int, threads::Int)
     ctx = JLKernelContext(threads, blocks)
     device_args = to_device.(Ref(ctx), args)
     tasks = Array{Task}(undef, threads)
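With `_gpu_call` gone, each back-end now overloads `gpu_call` itself on its backend type, as the `JLBackend` method above does. A hypothetical out-of-tree back-end would follow the same shape (`MyBackend` and `my_launch` are invented for illustration):

```julia
struct MyBackend <: GPUArrays.AbstractGPUBackend end

function GPUArrays.gpu_call(::MyBackend, kernel, args...; threads::Int, blocks::Int)
    # By this point the host-side gpu_call has already resolved the launch
    # configuration, via the total_threads heuristic or an explicit threads/blocks.
    my_launch(kernel, args...; threads = threads, blocks = blocks)  # hypothetical launcher
end
```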
