Skip to content

Commit ba52bbb

Browse files
committed
Better launch interface
1 parent 1953e8d commit ba52bbb

File tree

3 files changed

+123
-8
lines changed

3 files changed

+123
-8
lines changed

src/KernelAbstractions.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,11 @@ function unsafe_free! end
194194

195195
unsafe_free!(::AbstractArray) = return
196196

197+
"""
198+
Abstract type for all KernelAbstractions backends.
199+
"""
200+
abstract type Backend end
201+
197202
include("intrinsics.jl")
198203
import .KernelIntrinsics
199204
export KernelIntrinsics
@@ -500,11 +505,6 @@ constify(arg) = adapt(ConstAdaptor(), arg)
500505
# Backend hierarchy
501506
###
502507

503-
"""
504-
505-
Abstract type for all KernelAbstractions backends.
506-
"""
507-
abstract type Backend end
508508

509509
"""
510510
Abstract type for all GPU based KernelAbstractions backends.

src/intrinsics.jl

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module KernelIntrinsics
22

3+
import ..KernelAbstractions: Backend
4+
import GPUCompiler: split_kwargs, assign_args!
5+
36
"""
47
get_global_size()::@NamedTuple{x::Int, y::Int, z::Int}
58
@@ -100,7 +103,6 @@ kernel on the host.
100103
!!! note
101104
Backend implementations **must** implement:
102105
```
103-
KI.KIKernel(::NewBackend, f, args...; kwargs...)
104106
(kernel::KIKernel{<:NewBackend})(args...; numworkgroups=nothing, workgroupsize=nothing, kwargs...)
105107
```
106108
As well as the on-device functionality.
@@ -157,4 +159,88 @@ Used for certain algorithm optimizations.
157159
As well as the on-device functionality.
158160
"""
159161
multiprocessor_count(_) = 0
162+
163+
# TODO: docstring
164+
# kiconvert(::NewBackend, arg)
165+
function kiconvert end
166+
167+
# TODO: docstring
168+
# KI.kifunction(::NewBackend, f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT}
169+
function kifunction end
170+
171+
const MACRO_KWARGS = [:launch, :backend]
172+
const COMPILER_KWARGS = [:kernel, :name, :always_inline]
173+
const LAUNCH_KWARGS = [:numworkgroups, :workgroupsize]
174+
175+
macro kikernel(backend, ex...)
176+
call = ex[end]
177+
kwargs = map(ex[1:end-1]) do kwarg
178+
if kwarg isa Symbol
179+
:($kwarg = $kwarg)
180+
elseif Meta.isexpr(kwarg, :(=))
181+
kwarg
182+
else
183+
throw(ArgumentError("Invalid keyword argument '$kwarg'"))
184+
end
185+
end
186+
187+
# destructure the kernel call
188+
Meta.isexpr(call, :call) || throw(ArgumentError("final argument to @kikern should be a function call"))
189+
f = call.args[1]
190+
args = call.args[2:end]
191+
192+
code = quote end
193+
vars, var_exprs = assign_args!(code, args)
194+
195+
# group keyword argument
196+
macro_kwargs, compiler_kwargs, call_kwargs, other_kwargs =
197+
split_kwargs(kwargs, MACRO_KWARGS, COMPILER_KWARGS, LAUNCH_KWARGS)
198+
if !isempty(other_kwargs)
199+
key,val = first(other_kwargs).args
200+
throw(ArgumentError("Unsupported keyword argument '$key'"))
201+
end
202+
203+
# handle keyword arguments that influence the macro's behavior
204+
launch = true
205+
for kwarg in macro_kwargs
206+
key,val = kwarg.args
207+
if key === :launch
208+
isa(val, Bool) || throw(ArgumentError("`launch` keyword argument to @kikern should be a Bool"))
209+
launch = val::Bool
210+
else
211+
throw(ArgumentError("Unsupported keyword argument '$key'"))
212+
end
213+
end
214+
if !launch && !isempty(call_kwargs)
215+
error("@kikern with launch=false does not support launch-time keyword arguments; use them when calling the kernel")
216+
end
217+
218+
# FIXME: macro hygiene wrt. escaping kwarg values (this broke with 1.5)
219+
# we esc() the whole thing now, necessitating gensyms...
220+
@gensym f_var kernel_f kernel_args kernel_tt kernel
221+
222+
# convert the arguments, call the compiler and launch the kernel
223+
# while keeping the original arguments alive
224+
push!(code.args,
225+
quote
226+
$f_var = $f
227+
GC.@preserve $(vars...) $f_var begin
228+
$kernel_f = $kiconvert($backend, $f_var)
229+
$kernel_args = map(x -> $kiconvert($backend, x), ($(var_exprs...),))
230+
$kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...}
231+
$kernel = $kifunction($backend, $kernel_f, $kernel_tt; $(compiler_kwargs...))
232+
if $launch
233+
$kernel($(var_exprs...); $(call_kwargs...))
234+
end
235+
$kernel
236+
end
237+
end)
238+
239+
return esc(quote
240+
let
241+
$code
242+
end
243+
end)
244+
end
245+
160246
end

src/pocl/backend.jl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module POCLKernels
22

33
using ..POCL
44
using ..POCL: @device_override, cl, method_table
5-
using ..POCL: device
5+
using ..POCL: device, clconvert, clfunction
66

77
import KernelAbstractions as KA
88

@@ -138,9 +138,38 @@ function (obj::KA.Kernel{POCLBackend})(args...; ndrange = nothing, workgroupsize
138138
return nothing
139139
end
140140

141+
const KI = KA.KernelIntrinsics
142+
143+
KI.kiconvert(::POCLBackend, arg) = clconvert(arg)
144+
145+
function KI.kifunction(::POCLBackend, f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT}
146+
kern = clfunction(f, tt; name, kwargs...)
147+
KI.KIKernel{POCLBackend, typeof(kern)}(POCLBackend(), kern)
148+
end
149+
150+
function (obj::KI.KIKernel{POCLBackend})(args...; numworkgroups=nothing, workgroupsize=nothing, kwargs...)
151+
local_size = isnothing(workgroupsize) ? 1 : workgroupsize
152+
global_size = if isnothing(numworkgroups)
153+
1
154+
else
155+
numworkgroups*local_size
156+
end
157+
158+
obj.kern(args...; local_size, global_size)
159+
end
160+
161+
162+
function KI.kernel_max_work_group_size(::POCLBackend, kikern::KI.KIKernel{<:POCLBackend}; max_work_items::Int=typemax(Int))::Int
163+
4096
164+
end
165+
function KI.max_work_group_size(::POCLBackend)::Int
166+
4096
167+
end
168+
function KI.multiprocessor_count(::POCLBackend)::Int
169+
1
170+
end
141171

142172
## Indexing Functions
143-
const KI = KA.KernelIntrinsics
144173

145174
@device_override @inline function KI.get_local_id()
146175
return (; x = Int(get_local_id(1)), y = Int(get_local_id(2)), z = Int(get_local_id(3)))

0 commit comments

Comments
 (0)