|
1 | 1 | module KernelIntrinsics |
2 | 2 |
|
| 3 | +import ..KernelAbstractions: Backend |
| 4 | +import GPUCompiler: split_kwargs, assign_args! |
| 5 | + |
3 | 6 | """ |
4 | 7 | get_global_size()::@NamedTuple{x::Int, y::Int, z::Int} |
5 | 8 |
|
@@ -100,7 +103,6 @@ kernel on the host. |
100 | 103 | !!! note |
101 | 104 | Backend implementations **must** implement: |
102 | 105 | ``` |
103 | | - KI.KIKernel(::NewBackend, f, args...; kwargs...) |
104 | 106 | (kernel::KIKernel{<:NewBackend})(args...; numworkgroups=nothing, workgroupsize=nothing, kwargs...) |
105 | 107 | ``` |
106 | 108 | As well as the on-device functionality. |
@@ -157,4 +159,88 @@ Used for certain algorithm optimizations. |
157 | 159 | As well as the on-device functionality. |
158 | 160 | """ |
159 | 161 | multiprocessor_count(_) = 0 |
| 162 | + |
| 163 | +# TODO: docstring |
| 164 | +# kiconvert(::NewBackend, arg) |
| 165 | +function kiconvert end |
| 166 | + |
| 167 | +# TODO: docstring |
| 168 | +# KI.kifunction(::NewBackend, f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT} |
| 169 | +function kifunction end |
| 170 | + |
| 171 | +const MACRO_KWARGS = [:launch, :backend] |
| 172 | +const COMPILER_KWARGS = [:kernel, :name, :always_inline] |
| 173 | +const LAUNCH_KWARGS = [:numworkgroups, :workgroupsize] |
| 174 | + |
| 175 | +macro kikernel(backend, ex...) |
| 176 | + call = ex[end] |
| 177 | + kwargs = map(ex[1:end-1]) do kwarg |
| 178 | + if kwarg isa Symbol |
| 179 | + :($kwarg = $kwarg) |
| 180 | + elseif Meta.isexpr(kwarg, :(=)) |
| 181 | + kwarg |
| 182 | + else |
| 183 | + throw(ArgumentError("Invalid keyword argument '$kwarg'")) |
| 184 | + end |
| 185 | + end |
| 186 | + |
| 187 | + # destructure the kernel call |
| 188 | + Meta.isexpr(call, :call) || throw(ArgumentError("final argument to @kikern should be a function call")) |
| 189 | + f = call.args[1] |
| 190 | + args = call.args[2:end] |
| 191 | + |
| 192 | + code = quote end |
| 193 | + vars, var_exprs = assign_args!(code, args) |
| 194 | + |
| 195 | + # group keyword argument |
| 196 | + macro_kwargs, compiler_kwargs, call_kwargs, other_kwargs = |
| 197 | + split_kwargs(kwargs, MACRO_KWARGS, COMPILER_KWARGS, LAUNCH_KWARGS) |
| 198 | + if !isempty(other_kwargs) |
| 199 | + key,val = first(other_kwargs).args |
| 200 | + throw(ArgumentError("Unsupported keyword argument '$key'")) |
| 201 | + end |
| 202 | + |
| 203 | + # handle keyword arguments that influence the macro's behavior |
| 204 | + launch = true |
| 205 | + for kwarg in macro_kwargs |
| 206 | + key,val = kwarg.args |
| 207 | + if key === :launch |
| 208 | + isa(val, Bool) || throw(ArgumentError("`launch` keyword argument to @kikern should be a Bool")) |
| 209 | + launch = val::Bool |
| 210 | + else |
| 211 | + throw(ArgumentError("Unsupported keyword argument '$key'")) |
| 212 | + end |
| 213 | + end |
| 214 | + if !launch && !isempty(call_kwargs) |
| 215 | + error("@kikern with launch=false does not support launch-time keyword arguments; use them when calling the kernel") |
| 216 | + end |
| 217 | + |
| 218 | + # FIXME: macro hygiene wrt. escaping kwarg values (this broke with 1.5) |
| 219 | + # we esc() the whole thing now, necessitating gensyms... |
| 220 | + @gensym f_var kernel_f kernel_args kernel_tt kernel |
| 221 | + |
| 222 | + # convert the arguments, call the compiler and launch the kernel |
| 223 | + # while keeping the original arguments alive |
| 224 | + push!(code.args, |
| 225 | + quote |
| 226 | + $f_var = $f |
| 227 | + GC.@preserve $(vars...) $f_var begin |
| 228 | + $kernel_f = $kiconvert($backend, $f_var) |
| 229 | + $kernel_args = map(x -> $kiconvert($backend, x), ($(var_exprs...),)) |
| 230 | + $kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...} |
| 231 | + $kernel = $kifunction($backend, $kernel_f, $kernel_tt; $(compiler_kwargs...)) |
| 232 | + if $launch |
| 233 | + $kernel($(var_exprs...); $(call_kwargs...)) |
| 234 | + end |
| 235 | + $kernel |
| 236 | + end |
| 237 | + end) |
| 238 | + |
| 239 | + return esc(quote |
| 240 | + let |
| 241 | + $code |
| 242 | + end |
| 243 | + end) |
| 244 | +end |
| 245 | + |
160 | 246 | end |
0 commit comments