@@ -225,7 +225,7 @@ def format_of(ty):
225225 }[ty_to_cpp (ty )]
226226
227227 args_format = '' .join ([format_of (ty ) for ty in signature .values ()])
228- format = "iiiKKOOOO " + args_format
228+ format = "piiiKKOOOO " + args_format
229229 signature = ',' .join (map (_serialize_signature , signature .values ()))
230230 signature = list (filter (bool , signature .split (',' )))
231231 signature = {i : s for i , s in enumerate (signature )}
@@ -267,6 +267,12 @@ def format_of(ty):
267267 unsigned int blockDimY, unsigned int blockDimZ, \\
268268 unsigned int sharedMemBytes, hipStream_t stream, \\
269269 void **kernelParams, void **extra) \\
270+ FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, hipFunction_t f, \\
271+ unsigned int gridDimX, unsigned int gridDimY, \\
272+ unsigned int gridDimZ, unsigned int blockDimX, \\
273+ unsigned int blockDimY, unsigned int blockDimZ, \\
274+ unsigned int sharedMemBytes, hipStream_t stream, \\
275+ void **kernelParams, void **extra) \\
270276 FOR_EACH_ERR_FN(hipPointerGetAttribute, void *data, \\
271277 hipPointer_attribute attribute, hipDeviceptr_t ptr)
272278
@@ -338,14 +344,18 @@ def format_of(ty):
338344
339345#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
340346
341- static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{ ', ' + arg_decls if len (arg_decls ) > 0 else '' } ) {{
347+ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{ ', ' + arg_decls if len (arg_decls ) > 0 else '' } ) {{
342348 // printf("_launch hip kernel\\ n");
343349 hipDeviceptr_t global_scratch = 0;
344350 void *params[] = {{ { ', ' .join (params )} }};
351+ if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
352+ HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, { warp_size } *num_warps, 1, 1, shared_memory, stream, params, 0));
353+ return;
354+ }}
345355 if (gridX*gridY*gridZ > 0) {{
346- HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, { warp_size } *num_warps, 1, 1, shared_memory, stream, params, 0));
347- }}
356+ HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, { warp_size } *num_warps, 1, 1, shared_memory, stream, params, 0));
348357 }}
358+ }}
349359
350360typedef struct _DevicePtrInfo {{
351361 hipDeviceptr_t dev_ptr;
@@ -398,12 +408,14 @@ def format_of(ty):
398408 int gridX, gridY, gridZ;
399409 uint64_t _stream;
400410 uint64_t _function;
411+ int launch_cooperative_grid;
401412 PyObject *launch_enter_hook = NULL;
402413 PyObject *launch_exit_hook = NULL;
403414 PyObject *kernel_metadata = NULL;
404415 PyObject *launch_metadata = NULL;
405416 { ' ' .join ([f"{ _extracted_type (ty )} _arg{ i } ; " for i , ty in signature .items ()])}
406- if(!PyArg_ParseTuple(args, \" { format } \" , &gridX, &gridY, &gridZ, &_stream, &_function,
417+ if(!PyArg_ParseTuple(args, \" { format } \" , &launch_cooperative_grid,
418+ &gridX, &gridY, &gridZ, &_stream, &_function,
407419 &kernel_metadata, &launch_metadata,
408420 &launch_enter_hook, &launch_exit_hook { args_list } )) {{
409421 return NULL;
@@ -426,7 +438,7 @@ def format_of(ty):
426438
427439 // raise exception asap
428440 { "; " .join ([f"DevicePtrInfo ptr_info{ i } = getPointer(_arg{ i } , { i } ); if (!ptr_info{ i } .valid) return NULL;" if ty [0 ] == "*" else "" for i , ty in signature .items ()])} ;
429- _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{ ', ' + ', ' .join (internal_args_list ) if len (internal_args_list ) > 0 else '' } );
441+ _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{ ', ' + ', ' .join (internal_args_list ) if len (internal_args_list ) > 0 else '' } );
430442
431443 if(launch_exit_hook != Py_None){{
432444 PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -482,9 +494,10 @@ def __init__(self, src, metadata):
482494 src = make_launcher (constants , signature , metadata .warp_size )
483495 mod = compile_module_from_src (src , "__triton_launcher" )
484496 self .launch = mod .launch
497+ self .launch_cooperative_grid = metadata .launch_cooperative_grid
485498
486499 def __call__ (self , * args ):
487- self .launch (* args )
500+ self .launch (self . launch_cooperative_grid , * args )
488501
489502
490503class HIPDriver (GPUDriver ):
0 commit comments