Skip to content

Commit 02b3fb8

Browse files
authored
Merge pull request #243 from JuliaGPU/tb/gpu_call
Breaking gpu_call interface changes
2 parents 4463977 + f6bde2a commit 02b3fb8

File tree

4 files changed

+26
-26
lines changed

4 files changed

+26
-26
lines changed

src/device/execution.jl

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -33,35 +33,47 @@ host to influence how the kernel is executed. The following keyword arguments ar
3333
no other keyword arguments that influence the launch configuration are specified.
3434
- `threads::Int` and `blocks::Int`: configure exactly how many threads and blocks are
3535
launched. This cannot be used in combination with the `total_threads` argument.
36+
- `name::String`: inform the back end about the name of the kernel to be executed.
37+
This can be used to emit better diagnostics, and is useful with anonymous kernels.
3638
"""
3739
function gpu_call(kernel::Base.Callable, args...;
3840
target::AbstractArray=first(args),
3941
total_threads::Union{Int,Nothing}=nothing,
4042
threads::Union{Int,Nothing}=nothing,
4143
blocks::Union{Int,Nothing}=nothing,
42-
kwargs...)
43-
# determine how many threads/blocks to launch
44+
name::Union{String,Nothing}=nothing)
45+
# non-trivial default values for launch configuration
4446
if total_threads===nothing && threads===nothing && blocks===nothing
4547
total_threads = length(target)
46-
end
47-
if total_threads !== nothing
48-
if threads !== nothing || blocks !== nothing
49-
error("Cannot specify both total_threads and threads/blocks configuration")
50-
end
51-
blocks, threads = thread_blocks_heuristic(total_threads)
52-
else
48+
elseif total_threads===nothing
5349
if threads === nothing
5450
threads = 1
5551
end
5652
if blocks === nothing
5753
blocks = 1
5854
end
55+
elseif threads!==nothing || blocks!==nothing
56+
error("Cannot specify both total_threads and threads/blocks configuration")
57+
end
58+
59+
if total_threads !== nothing
60+
gpu_call(backend(target), kernel, args, total_threads; name=name)
61+
else
62+
gpu_call(backend(target), kernel, args, threads, blocks; name=name)
5963
end
64+
end
6065

61-
gpu_call(backend(target), kernel, args...; threads=threads, blocks=blocks, kwargs...)
66+
# gpu_call method with a simple launch configuration heuristic.
67+
# this can be specialised if more sophisticated heuristics are available.
68+
function gpu_call(backend::AbstractGPUBackend, kernel, args, total_threads::Int; kwargs...)
69+
threads = clamp(total_threads, 1, 256)
70+
blocks = max(ceil(Int, total_threads / threads), 1)
71+
72+
gpu_call(backend, kernel, args, threads, blocks; kwargs...)
6273
end
6374

64-
gpu_call(backend::AbstractGPUBackend, kernel, args...; kwargs...) = error("Not implemented") # COV_EXCL_LINE
75+
# bottom-line gpu_call method that is expected to be implemented by the back end
76+
gpu_call(backend::AbstractGPUBackend, kernel, args, threads::Int, blocks::Int; kwargs...) = error("Not implemented") # COV_EXCL_LINE
6577

6678
"""
6779
synchronize(A::AbstractArray)
@@ -72,10 +84,3 @@ function synchronize(A::AbstractArray)
7284
# fallback is a noop, for backends not needing synchronization. This
7385
# makes it easier to write generic code that also works for AbstractArrays
7486
end
75-
76-
function thread_blocks_heuristic(len::Integer)
77-
# TODO better threads default
78-
threads = clamp(len, 1, 256)
79-
blocks = max(ceil(Int, len / threads), 1)
80-
(blocks, threads)
81-
end

src/host/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ end
6060
@inline function Base.copyto!(dest::GPUDestArray, bc::Broadcasted{Nothing})
6161
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
6262
bc′ = Broadcast.preprocess(dest, bc)
63-
gpu_call(dest, bc′) do ctx, dest, bc′
63+
gpu_call(dest, bc′; name="broadcast") do ctx, dest, bc′
6464
let I = CartesianIndex(@cartesianidx(dest))
6565
#@inbounds dest[I] = bc′[I]
6666
@inbounds let

src/reference.jl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,8 @@ end
6464
Base.getindex(r::JlRefValue) = r.x
6565
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))
6666

67-
function GPUArrays.gpu_call(::JLBackend, f, args...; blocks::Int, threads::Int)
67+
function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
68+
name::Union{String,Nothing})
6869
ctx = JLKernelContext(threads, blocks)
6970
device_args = jlconvert.(args)
7071
tasks = Array{Task}(undef, threads)

test/testsuite/base.jl

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -143,12 +143,6 @@ function test_base(AT)
143143
@test compare(a-> repeat(a, 4, 3), AT, rand(Float32, 10, 15))
144144
end
145145

146-
@testset "heuristics" begin
147-
blocks, threads = thread_blocks_heuristic(0)
148-
@test blocks == 1
149-
@test threads == 1
150-
end
151-
152146
@testset "permutedims" begin
153147
@test compare(x->permutedims(x, [1, 2]), AT, rand(4, 4))
154148

0 commit comments

Comments (0)