Skip to content

Commit be52876

Browse files
author
William Moses
committed
wip
1 parent 95d5921 commit be52876

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 27 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -24,20 +24,13 @@ function compile(job)
2424
mod, meta = CUDA.GPUCompiler.compile(:llvm, job)
2525
string(mod)
2626
end
27-
println(string(modstr))
28-
@show job
29-
@show job.params
30-
@show job.source
31-
kernel = LLVMFunc{F,tt}(f, modstr)
32-
return modstr
33-
#=
3427
# check if we'll need the device runtime
3528
undefined_fs = filter(collect(functions(meta.ir))) do f
36-
isdeclaration(f) && !LLVM.isintrinsic(f)
29+
isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
3730
end
3831
intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
3932
"__nvvm_reflect" #= TODO: should have been optimized away =#]
40-
needs_cudadevrt = !isempty(setdiff(LLVM.name.(undefined_fs), intrinsic_fns))
33+
needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns))
4134

4235
# prepare invocations of CUDA compiler tools
4336
ptxas_opts = String[]
@@ -59,7 +52,7 @@ function compile(job)
5952
arch = "sm_$(cap.major)$(cap.minor)"
6053

6154
# validate use of parameter memory
62-
argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt
55+
argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt
6356
!isghosttype(dt) && !Core.Compiler.isconstType(dt)
6457
end
6558
param_usage = sum(sizeof, argtypes)
@@ -120,7 +113,7 @@ function compile(job)
120113
"--output-file", ptxas_output,
121114
ptx_input
122115
])
123-
proc, log = run_and_collect(`$(ptxas()) $ptxas_opts`)
116+
proc, log = CUDA.run_and_collect(`$(ptxas()) $ptxas_opts`)
124117
log = strip(log)
125118
if !success(proc)
126119
reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
@@ -139,8 +132,7 @@ function compile(job)
139132
@debug "PTX compiler log:\n" * log
140133
end
141134
rm(ptx_input)
142-
=#
143-
#=
135+
144136
# link device libraries, if necessary
145137
#
146138
# this requires relocatable device code, which prevents certain optimizations and
@@ -180,8 +172,12 @@ function compile(job)
180172
image = read(ptxas_output)
181173
rm(ptxas_output)
182174
end
183-
=#
184-
return (image, entry=LLVM.name(meta.entry))
175+
176+
println(string(modstr))
177+
@show job
178+
@show job.source
179+
@show job.config
180+
LLVMFunc{F,job.source.specTypes}(f, modstr, image, LLVM.name(meta.entry))
185181
end
186182

187183
# link into an executable kernel
@@ -193,10 +189,24 @@ end
193189
# Wrapper type constructed at the end of `compile` as
# `LLVMFunc{F,job.source.specTypes}(f, modstr, image, LLVM.name(meta.entry))`.
# The `tt` type parameter is not used by any field; it only appears in the type itself.
# Lines marked `NNN+` in this diff view are fields added by this commit.
struct LLVMFunc{F,tt}
194190
# Value of type `F`; `compile` passes `f` in this position.
f::F
195191
# Module text; `compile` passes `modstr` in this position.
mod::String
192+
# NOTE(review): untyped field (implicitly `Any`). `compile` passes `image`, which the
# visible hunk assigns from `read(ptxas_output)` — presumably raw bytes; confirm and
# consider a concrete field type.
image
193+
# Entry name; `compile` passes `LLVM.name(meta.entry)` in this position.
entry::String
196194
end
197195

198-
# Call overload for `LLVMFunc` instances. This diff view shows two signatures: the one
# directly below is the removed version (marker `198-`), and the keyword-argument version
# after it (markers `196+`/`197+`) is the one added by this commit.
function (func::LLVMFunc{F,tt})(args...) where{F, tt}
199-
196+
# Added signature: `blocks` and `threads` default to 1, `shmem` defaults to 0.
function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuDim=1,
197+
shmem::Integer=0) where{F, tt}
198+
# Convert both launch dimensions with `CUDA.CuDim3` so their `.x`/`.y`/`.z` fields can be
# passed individually to the launch call below.
blockdim = CUDA.CuDim3(blocks)
199+
threaddim = CUDA.CuDim3(threads)
200+
201+
# Debug output of the positional arguments; `args` is not used anywhere else in this body.
@show args
202+
203+
# void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
204+
# size_t opaque_len, XlaCustomCallStatus* status) {
205+
206+
# NOTE(review): `f`, `stream` and `kernelParams` are not defined in this body — the
# receiver is named `func` and `args` is only printed. Presumably these still need to be
# wired up (commit message is "wip"); confirm before relying on this call.
CUDA.cuLaunchKernel(f,
207+
blockdim.x, blockdim.y, blockdim.z,
208+
threaddim.x, threaddim.y, threaddim.z,
209+
shmem, stream, kernelParams, C_NULL)
200210
end
201211

202212
# cache of compilation caches, per context

0 commit comments

Comments
 (0)