@@ -42,28 +42,38 @@ function gpu_call(kernel::Base.Callable, args...;
42
42
threads:: Union{Int,Nothing} = nothing ,
43
43
blocks:: Union{Int,Nothing} = nothing ,
44
44
name:: Union{String,Nothing} = nothing )
45
- # determine how many threads/blocks to launch
45
+ # non-trivial default values for launch configuration
46
46
if total_threads=== nothing && threads=== nothing && blocks=== nothing
47
47
total_threads = length (target)
48
- end
49
- if total_threads != = nothing
50
- if threads != = nothing || blocks != = nothing
51
- error (" Cannot specify both total_threads and threads/blocks configuration" )
52
- end
53
- blocks, threads = thread_blocks_heuristic (total_threads)
54
- else
48
+ elseif total_threads=== nothing
55
49
if threads === nothing
56
50
threads = 1
57
51
end
58
52
if blocks === nothing
59
53
blocks = 1
60
54
end
55
+ elseif threads!= = nothing || blocks!= = nothing
56
+ error (" Cannot specify both total_threads and threads/blocks configuration" )
57
+ end
58
+
59
+ if total_threads != = nothing
60
+ gpu_call (backend (target), kernel, args, total_threads; name= name)
61
+ else
62
+ gpu_call (backend (target), kernel, args, threads, blocks; name= name)
61
63
end
64
+ end
62
65
63
- gpu_call (backend (target), kernel, args... ; threads= threads, blocks= blocks, name= name)
66
+ # gpu_call method with a simple launch configuration heuristic.
67
+ # this can be specialised if more sophisticated heuristics are available.
68
+ function gpu_call (backend:: AbstractGPUBackend , kernel, args, total_threads:: Int ; kwargs... )
69
+ threads = clamp (total_threads, 1 , 256 )
70
+ blocks = max (ceil (Int, total_threads / threads), 1 )
71
+
72
+ gpu_call (backend, kernel, args, threads, blocks; kwargs... )
64
73
end
65
74
66
- gpu_call (backend:: AbstractGPUBackend , kernel, args... ; kwargs... ) = error (" Not implemented" ) # COV_EXCL_LINE
75
+ # bottom-line gpu_call method that is expected to be implemented by the back end
76
+ gpu_call (backend:: AbstractGPUBackend , kernel, args, threads:: Int , blocks:: Int ; kwargs... ) = error (" Not implemented" ) # COV_EXCL_LINE
67
77
68
78
"""
69
79
synchronize(A::AbstractArray)
@@ -74,10 +84,3 @@ function synchronize(A::AbstractArray)
74
84
# fallback is a noop, for backends not needing synchronization. This
75
85
# makes it easier to write generic code that also works for AbstractArrays
76
86
end
77
-
78
- function thread_blocks_heuristic (len:: Integer )
79
- # TODO better threads default
80
- threads = clamp (len, 1 , 256 )
81
- blocks = max (ceil (Int, len / threads), 1 )
82
- (blocks, threads)
83
- end
0 commit comments