@@ -685,40 +685,48 @@ def get_templated_typenames(template_parameters, template_arguments):
685685
686686def wrap_templated_kernel (kernel_string , kernel_name ):
687687 """rewrite kernel_string to insert wrapper function for templated kernel"""
688- #parse kernel_name to find template_arguments and real kernel name
688+ # parse kernel_name to find template_arguments and real kernel name
689689 name = kernel_name .split ("<" )[0 ]
690690 template_arguments = re .search (r".*?<(.*)>" , kernel_name , re .S ).group (1 ).split (',' )
691691
692- #parse templated kernel definition
693- #relatively strict regex that does not allow nested template parameters like vector<TF>
694- #within the template parameter list
695- regex = r"template\s*<([^>]*?)>\s*__global__\s+void\s+" + name + r"\s*\((.*?)\)\s*\{"
692+ # parse templated kernel definition
693+ # relatively strict regex that does not allow nested template parameters like vector<TF>
694+ # within the template parameter list
695+ regex = r"template\s*<([^>]*?)>\s*__global__\s+void\s+(__launch_bounds__\([^\)]+?\)\s+)? " + name + r"\s*\((.*?)\)\s*\{"
696696 match = re .search (regex , kernel_string , re .S )
697697 if not match :
698698 raise ValueError ("could not find templated kernel definition" )
699699
700700 template_parameters = match .group (1 ).split (',' )
701- argument_list = match .group (2 ).split (',' )
701+ argument_list = match .group (3 ).split (',' )
702702 argument_list = [s .strip () for s in argument_list ] #remove extra whitespace around 'type name' strings
703703
704704 type_list , name_list = split_argument_list (argument_list )
705705
706706 templated_typenames = get_templated_typenames (template_parameters , template_arguments )
707707 apply_template_typenames (type_list , templated_typenames )
708708
709- #replace __global__ with __device__ in the templated kernel definition
710- #could do a more precise replace, but __global__ cannot be used elsewhere in the definition
709+ # replace __global__ with __device__ in the templated kernel definition
710+ # could do a more precise replace, but __global__ cannot be used elsewhere in the definition
711711 definition = match .group (0 ).replace ("__global__" , "__device__" )
712712
713- #generate code for the compile-time template instantiation
713+ # there is a __launch_bounds__() group that is matched
714+ launch_bounds = ""
715+ if match .group (2 ):
716+ print (f"found launch bounds: { match .group (2 )= } " )
717+
718+ definition = definition .replace (match .group (2 ), " " )
719+ launch_bounds = match .group (2 )
720+
721+ # generate code for the compile-time template instantiation
714722 template_instantiation = f"template __device__ void { kernel_name } (" + ", " .join (type_list ) + ");\n "
715723
716- #generate code for the wrapper kernel
724+ # generate code for the wrapper kernel
717725 new_arg_list = ", " .join ([" " .join ((a , b )) for a , b in zip (type_list , name_list )])
718- wrapper_function = "\n extern \" C\" __global__ void " + name + "_wrapper(" + new_arg_list + ") {\n " + \
726+ wrapper_function = "\n extern \" C\" __global__ void " + launch_bounds + name + "_wrapper(" + new_arg_list + ") {\n " + \
719727 kernel_name + "(" + ", " .join (name_list ) + ");\n }\n "
720728
721- #copy kernel_string, replace definition and append template instantiation and wrapper function
729+ # copy kernel_string, replace definition and append template instantiation and wrapper function
722730 new_kernel_string = kernel_string [:]
723731 new_kernel_string = new_kernel_string .replace (match .group (0 ), definition )
724732 new_kernel_string += "\n " + template_instantiation
0 commit comments