Skip to content

Commit 38bd2dc

Browse files
add support for templates in cuda-python backend
1 parent 23e8c17 commit 38bd2dc

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

kernel_tuner/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(self, kernel_name, kernel_sources, lang, defines=None):
6868
lang = util.detect_language(kernel_string)
6969

7070
# The validity of lang is checked later, when creating the DeviceInterface
71-
self.lang = lang
71+
self.lang = lang.upper()
7272

7373
def get_kernel_string(self, index=0, params=None):
7474
""" retrieve the kernel source with the given index and return as a string
@@ -529,7 +529,7 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose)
529529
kernel_options.block_size_names)
530530

531531
#check for templated kernel
532-
if kernel_source.lang == "CUDA" and "<" in name and ">" in name:
532+
if kernel_source.lang in ["CUDA", "NVCUDA"] and "<" in name and ">" in name:
533533
kernel_string, name = wrap_templated_kernel(kernel_string, name)
534534

535535
#collect everything we know about this instance and return it

kernel_tuner/nvcuda.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -199,21 +199,30 @@ def compile(self, kernel_instance):
199199
self.compiler_options.append(f"--gpu-architecture=compute_{self.cc}")
200200

201201
err, program = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"CUDAProgram", 0, [], [])
202-
error_check(err)
203-
err = nvrtc.nvrtcCompileProgram(program, len(compiler_options), compiler_options)
204-
error_check(err)
205-
err, size = nvrtc.nvrtcGetPTXSize(program)
206-
error_check(err)
207-
buff = b' ' * size
208-
err = nvrtc.nvrtcGetPTX(program, buff)
209-
error_check(err)
210-
err, self.current_module = cuda.cuModuleLoadData(np.char.array(buff))
211-
if err == cuda.CUresult.CUDA_ERROR_INVALID_PTX:
212-
raise SkippableFailure("uses too much shared data")
213-
else:
202+
try:
214203
error_check(err)
215-
err, self.func = cuda.cuModuleGetFunction(self.current_module, str.encode(kernel_name))
216-
error_check(err)
204+
err = nvrtc.nvrtcCompileProgram(program, len(compiler_options), compiler_options)
205+
error_check(err)
206+
err, size = nvrtc.nvrtcGetPTXSize(program)
207+
error_check(err)
208+
buff = b' ' * size
209+
err = nvrtc.nvrtcGetPTX(program, buff)
210+
error_check(err)
211+
err, self.current_module = cuda.cuModuleLoadData(np.char.array(buff))
212+
if err == cuda.CUresult.CUDA_ERROR_INVALID_PTX:
213+
raise SkippableFailure("uses too much shared data")
214+
else:
215+
error_check(err)
216+
err, self.func = cuda.cuModuleGetFunction(self.current_module, str.encode(kernel_name))
217+
error_check(err)
218+
219+
except RuntimeError as re:
220+
_, n = nvrtc.nvrtcGetProgramLogSize(program)
221+
log = b' ' * n
222+
nvrtc.nvrtcGetProgramLog(program, log)
223+
print(log.decode('utf-8'))
224+
raise re
225+
217226
return self.func
218227

219228
def start_event(self):

0 commit comments

Comments
 (0)