@@ -24,20 +24,13 @@ function compile(job)
2424 mod, meta = CUDA. GPUCompiler. compile (:llvm , job)
2525 string (mod)
2626 end
27- println (string (modstr))
28- @show job
29- @show job. params
30- @show job. source
31- kernel = LLVMFunc {F,tt} (f, modstr)
32- return modstr
33- #=
3427 # check if we'll need the device runtime
3528 undefined_fs = filter (collect (functions (meta. ir))) do f
36- isdeclaration(f) && !LLVM.isintrinsic(f)
29+ isdeclaration (f) && ! CUDA . LLVM. isintrinsic (f)
3730 end
3831 intrinsic_fns = [" vprintf" , " malloc" , " free" , " __assertfail" ,
3932 " __nvvm_reflect" #= TODO : should have been optimized away =# ]
40- needs_cudadevrt = !isempty(setdiff(LLVM.name.(undefined_fs), intrinsic_fns))
33+ needs_cudadevrt = ! isempty (setdiff (CUDA . LLVM. name .(undefined_fs), intrinsic_fns))
4134
4235 # prepare invocations of CUDA compiler tools
4336 ptxas_opts = String[]
@@ -59,7 +52,7 @@ function compile(job)
5952 arch = " sm_$(cap. major)$(cap. minor) "
6053
6154 # validate use of parameter memory
62- argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt
55+ argtypes = filter ([CUDA . KernelState, job. source. specTypes. parameters... ]) do dt
6356 ! isghosttype (dt) && ! Core. Compiler. isconstType (dt)
6457 end
6558 param_usage = sum (sizeof, argtypes)
@@ -120,7 +113,7 @@ function compile(job)
120113 " --output-file" , ptxas_output,
121114 ptx_input
122115 ])
123- proc, log = run_and_collect(`$(ptxas()) $ptxas_opts`)
116+ proc, log = CUDA . run_and_collect (` $(ptxas ()) $ptxas_opts ` )
124117 log = strip (log)
125118 if ! success (proc)
126119 reason = proc. termsignal > 0 ? " ptxas received signal $(proc. termsignal) " :
@@ -139,8 +132,7 @@ function compile(job)
139132 @debug " PTX compiler log:\n " * log
140133 end
141134 rm (ptx_input)
142- =#
143- #=
135+
144136 # link device libraries, if necessary
145137 #
146138 # this requires relocatable device code, which prevents certain optimizations and
@@ -180,8 +172,12 @@ function compile(job)
180172 image = read (ptxas_output)
181173 rm (ptxas_output)
182174 end
183- =#
184- return (image, entry= LLVM. name (meta. entry))
175+
176+ println (string (modstr))
177+ @show job
178+ @show job. source
179+ @show job. config
180+ LLVMFunc {F,job.source.specTypes} (f, modstr, image, LLVM. name (meta. entry))
185181end
186182
187183# link into an executable kernel
@@ -193,10 +189,24 @@ end
# Callable wrapper for a compiled kernel, parameterised on the function type `F`
# and a second type parameter `tt` (set to `job.source.specTypes` where `compile`
# constructs it).
#
# Fields, as populated by the constructor call at the end of `compile`:
#   f      - passed the name `f` from `compile`'s scope.
#   mod    - passed `modstr`; `compile` prints it via `println(string(modstr))`.
#   image  - passed `image`; one visible branch of `compile` sets it with
#            `read(ptxas_output)`. Other branches are not visible here.
#   entry  - passed `LLVM.name(meta.entry)`.
193189struct LLVMFunc{F,tt}
194190 f:: F
195191 mod:: String
# NOTE(review): `image` has no type annotation, so the field is `Any`-typed.
# Consider a concrete type once the producer's return type is confirmed.
192+ image
193+ entry:: String
196194end
197195
# Call overload for `LLVMFunc`: converts the `blocks`/`threads` keyword arguments
# to `CUDA.CuDim3` and forwards them, together with `shmem`, to
# `CUDA.cuLaunchKernel`.
#
# Keyword arguments:
#   blocks  - grid dimensions (`CUDA.CuDim`, default 1)
#   threads - block dimensions (`CUDA.CuDim`, default 1)
#   shmem   - forwarded to the launch call after the thread dimensions (default 0)
#
# NOTE(review): as written this method looks unfinished:
#   * `f`, `stream` and `kernelParams` are not bound anywhere in this method
#     (its only names are `func`, `args`, `blocks`, `threads`, `shmem`). Unless
#     they are globals defined elsewhere, the launch call raises `UndefVarError`.
#   * `args` is only printed by `@show`; it is never turned into kernel parameters.
#   * `func.image` and `func.entry` are not used here. Presumably they are meant
#     to provide the function handle for the launch; TODO confirm.
198- function (func:: LLVMFunc{F,tt} )(args... ) where {F, tt}
199-
196+ function (func:: LLVMFunc{F,tt} )(args... ; blocks:: CUDA.CuDim = 1 , threads:: CUDA.CuDim = 1 ,
197+ shmem:: Integer = 0 ) where {F, tt}
198+ blockdim = CUDA. CuDim3 (blocks)
199+ threaddim = CUDA. CuDim3 (threads)
200+
# Debug output of the call arguments.
201+ @show args
202+
# C signature kept for reference by the author. Presumably the intended
# integration point for an XLA custom call; TODO confirm.
203+ # void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
204+ # size_t opaque_len, XlaCustomCallStatus* status) {
205+
# NOTE(review): `f`, `stream`, `kernelParams` are undefined in this scope (see above).
206+ CUDA. cuLaunchKernel (f,
207+ blockdim. x, blockdim. y, blockdim. z,
208+ threaddim. x, threaddim. y, threaddim. z,
209+ shmem, stream, kernelParams, C_NULL )
200210end
201211
202212# cache of compilation caches, per context
0 commit comments