|
| 1 | +function kernelFunc(funcExpr) |
| 2 | + if @capture(funcExpr, function fname_(fargs__) where Targs__ fbody__ end) |
| 3 | + kernelfunc = quote |
| 4 | + function $fname(args::Tuple{WgpuArray}, workgroupSizes, workgroupCount) |
| 5 | + $preparePipeline($(funcExpr), args...) |
| 6 | + $compute($(funcExpr), args...; workgroupSizes=workgroupSizes, workgroupCount=workgroupCount) |
| 7 | + return nothing |
| 8 | + end |
| 9 | + end |> unblock |
| 10 | + return esc(quote $kernelfunc end) |
| 11 | + else |
| 12 | + error("Couldnt capture function") |
| 13 | + end |
| 14 | +end |
| 15 | + |
| 16 | +function getFunctionBlock(func, args) |
| 17 | + fString = CodeTracking.definition(String, which(func, args)) |
| 18 | + return Meta.parse(fString |> first) |
| 19 | +end |
| 20 | + |
| 21 | +function wgpuCall(kernelObj::WGPUKernelObject, args...) |
| 22 | + kernelObj.kernelFunc(args...) |
| 23 | +end |
| 24 | + |
| 25 | +macro wgpukernel(launch, wgSize, wgCount, ex) |
| 26 | + code = quote end |
| 27 | + @gensym f_var kernel_f kernel_args kernel_tt kernel |
| 28 | + if @capture(ex, fname_(fargs__)) |
| 29 | + (vars, var_exprs) = assign_args!(code, fargs) |
| 30 | + push!(code.args, quote |
| 31 | + $kernel_args = ($(var_exprs...),) |
| 32 | + $kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...} |
| 33 | + kernel = function wgpuKernel(args...) |
| 34 | + $preparePipeline($fname, args...; workgroupSizes=$wgSize, workgroupCount=$wgCount) |
| 35 | + $compute($fname, args...; workgroupSizes=$wgSize, workgroupCount=$wgCount) |
| 36 | + end |
| 37 | + if $launch == true |
| 38 | + wgpuCall(WGPUKernelObject(kernel), $(kernel_args)...) |
| 39 | + else |
| 40 | + WGPUKernelObject(kernel) |
| 41 | + end |
| 42 | + end |
| 43 | + ) |
| 44 | + end |
| 45 | + esc(code) |
| 46 | +end |
0 commit comments