Skip to content

Commit 672e0c0

Browse files
add support for launch_bounds in templated kernels
1 parent 0bac305 commit 672e0c0

File tree

2 files changed

+79
-12
lines changed

2 files changed

+79
-12
lines changed

kernel_tuner/core.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -685,40 +685,48 @@ def get_templated_typenames(template_parameters, template_arguments):
685685

686686
def wrap_templated_kernel(kernel_string, kernel_name):
687687
"""rewrite kernel_string to insert wrapper function for templated kernel"""
688-
#parse kernel_name to find template_arguments and real kernel name
688+
# parse kernel_name to find template_arguments and real kernel name
689689
name = kernel_name.split("<")[0]
690690
template_arguments = re.search(r".*?<(.*)>", kernel_name, re.S).group(1).split(',')
691691

692-
#parse templated kernel definition
693-
#relatively strict regex that does not allow nested template parameters like vector<TF>
694-
#within the template parameter list
695-
regex = r"template\s*<([^>]*?)>\s*__global__\s+void\s+" + name + r"\s*\((.*?)\)\s*\{"
692+
# parse templated kernel definition
693+
# relatively strict regex that does not allow nested template parameters like vector<TF>
694+
# within the template parameter list
695+
regex = r"template\s*<([^>]*?)>\s*__global__\s+void\s+(__launch_bounds__\([^\)]+?\)\s+)?" + name + r"\s*\((.*?)\)\s*\{"
696696
match = re.search(regex, kernel_string, re.S)
697697
if not match:
698698
raise ValueError("could not find templated kernel definition")
699699

700700
template_parameters = match.group(1).split(',')
701-
argument_list = match.group(2).split(',')
701+
argument_list = match.group(3).split(',')
702702
argument_list = [s.strip() for s in argument_list] #remove extra whitespace around 'type name' strings
703703

704704
type_list, name_list = split_argument_list(argument_list)
705705

706706
templated_typenames = get_templated_typenames(template_parameters, template_arguments)
707707
apply_template_typenames(type_list, templated_typenames)
708708

709-
#replace __global__ with __device__ in the templated kernel definition
710-
#could do a more precise replace, but __global__ cannot be used elsewhere in the definition
709+
# replace __global__ with __device__ in the templated kernel definition
710+
# could do a more precise replace, but __global__ cannot be used elsewhere in the definition
711711
definition = match.group(0).replace("__global__", "__device__")
712712

713-
#generate code for the compile-time template instantiation
713+
# there is a __launch_bounds__() group that is matched
714+
launch_bounds = ""
715+
if match.group(2):
716+
print(f"found launch bounds: {match.group(2)=}")
717+
718+
definition = definition.replace(match.group(2), " ")
719+
launch_bounds = match.group(2)
720+
721+
# generate code for the compile-time template instantiation
714722
template_instantiation = f"template __device__ void {kernel_name}(" + ", ".join(type_list) + ");\n"
715723

716-
#generate code for the wrapper kernel
724+
# generate code for the wrapper kernel
717725
new_arg_list = ", ".join([" ".join((a, b)) for a, b in zip(type_list, name_list)])
718-
wrapper_function = "\nextern \"C\" __global__ void " + name + "_wrapper(" + new_arg_list + ") {\n " + \
726+
wrapper_function = "\nextern \"C\" __global__ void " + launch_bounds + name + "_wrapper(" + new_arg_list + ") {\n " + \
719727
kernel_name + "(" + ", ".join(name_list) + ");\n}\n"
720728

721-
#copy kernel_string, replace definition and append template instantiation and wrapper function
729+
# copy kernel_string, replace definition and append template instantiation and wrapper function
722730
new_kernel_string = kernel_string[:]
723731
new_kernel_string = new_kernel_string.replace(match.group(0), definition)
724732
new_kernel_string += "\n" + template_instantiation

test/test_core.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,62 @@ def test_wrap_templated_kernel():
219219
#check if original kernel is called
220220
assert "vector_add<float>(c, a, b, n);" in ans
221221

222+
def test_wrap_templated_kernel2():
223+
kernel_string = """
224+
template<typename TF> __global__ void __launch_bounds__(THREADS_PER_BLOCK, BLOCKS_PER_SM) vector_add(TF *c, const TF *__restrict__ a, TF * b , int n) {
225+
auto i = blockIdx.x * block_size_x + threadIdx.x;
226+
if (i<n) {
227+
c[i] = a[i] + b[i];
228+
}
229+
}
230+
"""
231+
kernel_name = "vector_add<float>"
232+
# test no exception is thrown
233+
ans, _ = core.wrap_templated_kernel(kernel_string, kernel_name)
234+
assert True
235+
236+
def test_wrap_templated_kernel3():
237+
kernel_string = """
238+
template<typename TF> __global__ void __launch_bounds__(THREADS_PER_BLOCK, BLOCKS_PER_SM) vector_add1(TF *c, const TF *__restrict__ a, TF * b , int n) {
239+
auto i = blockIdx.x * block_size_x + threadIdx.x;
240+
if (i<n) {
241+
c[i] = a[i] + b[i];
242+
}
243+
}
244+
245+
template<typename TF> __global__ void __launch_bounds__(THREADS_PER_BLOCK, BLOCKS_PER_WRONG) test_vector_add1(TF *a, const TF *__restrict__ a, TF * b , int n) {
246+
auto i = blockIdx.x * block_size_x + threadIdx.x;
247+
if (i<n) {
248+
c[i] = a[i] + b[i];
249+
}
250+
}
251+
"""
252+
kernel_name = "vector_add1<float>"
253+
ans, _ = core.wrap_templated_kernel(kernel_string, kernel_name)
254+
255+
# test that the template wrapper matches the right kernel (the first and not the second)
256+
assert 'extern "C" __global__ void __launch_bounds__(THREADS_PER_BLOCK, BLOCKS_PER_SM) vector_add1_wrapper(float * c, const float *__restrict__ a, float * b, int n)' in ans
257+
258+
259+
def test_wrap_templated_kernel4():
260+
kernel_string = """
261+
template<typename TF> __global__ void __launch_bounds__(THREADS_PER_BLOCK, BLOCKS_PER_WRONG) test_vector_add1(TF *a, const TF *__restrict__ a, TF * b , int n) {
262+
auto i = blockIdx.x * block_size_x + threadIdx.x;
263+
if (i<n) {
264+
c[i] = a[i] + b[i];
265+
}
266+
}
267+
268+
template<typename TF> __global__ void __launch_bounds__(THREADS_PER_BLOCK, BLOCKS_PER_SM) vector_add1(TF *c, const TF *__restrict__ a, TF * b , int n) {
269+
auto i = blockIdx.x * block_size_x + threadIdx.x;
270+
if (i<n) {
271+
c[i] = a[i] + b[i];
272+
}
273+
}
274+
275+
"""
276+
kernel_name = "vector_add1<float>"
277+
ans, _ = core.wrap_templated_kernel(kernel_string, kernel_name)
278+
279+
# test that the template wrapper matches the right kernel (the second not the first)
280+
assert 'extern "C" __global__ void __launch_bounds__(THREADS_PER_BLOCK, BLOCKS_PER_SM) vector_add1_wrapper(float * c, const float *__restrict__ a, float * b, int n)' in ans

0 commit comments

Comments
 (0)