Commit f6bde2a

Put the launch configuration heuristic in gpu_call.
This makes it possible to overload gpu_call without having to do certain work twice (e.g., converting arguments to their kernel equivalents).
1 parent deb8ecb commit f6bde2a
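
With the heuristic now behind its own gpu_call method, a back end that has a better launch heuristic can overload that method directly instead of re-implementing the keyword handling. The sketch below is only illustrative: MyBackend, convert_arg, compile_kernel, max_threads and launch are hypothetical placeholders, not part of this commit or of GPUArrays; only the gpu_call method signatures are taken from the diff.

using GPUArrays

# Hypothetical back end: convert arguments and compile the kernel once, then
# derive the launch configuration from the compiled kernel rather than from
# the generic clamp-to-256 fallback.
struct MyBackend <: GPUArrays.AbstractGPUBackend end

function GPUArrays.gpu_call(backend::MyBackend, kernel, args, total_threads::Int; name=nothing)
    device_args = map(convert_arg, args)                   # convert once
    fun = compile_kernel(kernel, device_args; name=name)   # compile once
    threads = clamp(total_threads, 1, max_threads(fun))    # backend-specific heuristic
    blocks = max(cld(total_threads, threads), 1)
    launch(fun, device_args...; threads=threads, blocks=blocks)
end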

File tree

3 files changed: +21 -24 lines


src/device/execution.jl

Lines changed: 20 additions & 17 deletions
@@ -42,28 +42,38 @@ function gpu_call(kernel::Base.Callable, args...;
                    threads::Union{Int,Nothing}=nothing,
                    blocks::Union{Int,Nothing}=nothing,
                    name::Union{String,Nothing}=nothing)
-    # determine how many threads/blocks to launch
+    # non-trivial default values for launch configuration
     if total_threads===nothing && threads===nothing && blocks===nothing
         total_threads = length(target)
-    end
-    if total_threads !== nothing
-        if threads !== nothing || blocks !== nothing
-            error("Cannot specify both total_threads and threads/blocks configuration")
-        end
-        blocks, threads = thread_blocks_heuristic(total_threads)
-    else
+    elseif total_threads===nothing
         if threads === nothing
             threads = 1
         end
         if blocks === nothing
             blocks = 1
         end
+    elseif threads!==nothing || blocks!==nothing
+        error("Cannot specify both total_threads and threads/blocks configuration")
+    end
+
+    if total_threads !== nothing
+        gpu_call(backend(target), kernel, args, total_threads; name=name)
+    else
+        gpu_call(backend(target), kernel, args, threads, blocks; name=name)
     end
+end

-    gpu_call(backend(target), kernel, args...; threads=threads, blocks=blocks, name=name)
+# gpu_call method with a simple launch configuration heuristic.
+# this can be specialised if more sophisticated heuristics are available.
+function gpu_call(backend::AbstractGPUBackend, kernel, args, total_threads::Int; kwargs...)
+    threads = clamp(total_threads, 1, 256)
+    blocks = max(ceil(Int, total_threads / threads), 1)
+
+    gpu_call(backend, kernel, args, threads, blocks; kwargs...)
 end

-gpu_call(backend::AbstractGPUBackend, kernel, args...; kwargs...) = error("Not implemented") # COV_EXCL_LINE
+# bottom-line gpu_call method that is expected to be implemented by the back end
+gpu_call(backend::AbstractGPUBackend, kernel, args, threads::Int, blocks::Int; kwargs...) = error("Not implemented") # COV_EXCL_LINE

 """
     synchronize(A::AbstractArray)
@@ -74,10 +84,3 @@ function synchronize(A::AbstractArray)
     # fallback is a noop, for backends not needing synchronization. This
     # makes it easier to write generic code that also works for AbstractArrays
 end
-
-function thread_blocks_heuristic(len::Integer)
-    # TODO better threads default
-    threads = clamp(len, 1, 256)
-    blocks = max(ceil(Int, len / threads), 1)
-    (blocks, threads)
-end
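
A back end that does not specialise the new total_threads method falls back to the heuristic above, which computes the same numbers the removed thread_blocks_heuristic did. Two worked examples in plain Julia (the values follow directly from the code in this hunk):

# total_threads = 1000: threads are capped at 256 per block, giving 4 blocks
clamp(1000, 1, 256)            # -> 256
max(ceil(Int, 1000 / 256), 1)  # -> 4

# total_threads = 0: the edge case the heuristics testset (removed below in
# test/testsuite/base.jl) used to check; it still resolves to 1 thread, 1 block
clamp(0, 1, 256)               # -> 1
max(ceil(Int, 0 / 1), 1)       # -> 1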

src/reference.jl

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ end
 Base.getindex(r::JlRefValue) = r.x
 Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))

-function GPUArrays.gpu_call(::JLBackend, f, args...; blocks::Int, threads::Int,
+function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
                             name::Union{String,Nothing})
     ctx = JLKernelContext(threads, blocks)
     device_args = jlconvert.(args)

test/testsuite/base.jl

Lines changed: 0 additions & 6 deletions
@@ -143,12 +143,6 @@ function test_base(AT)
         @test compare(a-> repeat(a, 4, 3), AT, rand(Float32, 10, 15))
     end

-    @testset "heuristics" begin
-        blocks, threads = thread_blocks_heuristic(0)
-        @test blocks == 1
-        @test threads == 1
-    end
-
     @testset "permutedims" begin
         @test compare(x->permutedims(x, [1, 2]), AT, rand(4, 4))
