Skip to content

Commit 8427f69

Browse files
authored
[AMD][Launcher] Support cooperative grid launch (#5700)
This change is a follow up to triton-lang/triton#5381 and is intended to add grid synchronization similar to what cooperative groups do for AMD. This PR adds support for Cooperative Grid Launch for AMD using the HIP API `hipModuleLaunchCooperativeKernel`: https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
1 parent b27b9d5 commit 8427f69

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

python/test/unit/language/test_core.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1731,7 +1731,6 @@ def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
17311731

17321732

17331733
@pytest.mark.interpreter
1734-
@pytest.mark.skipif(is_hip(), reason="Not implemented for AMD At this moment")
17351734
def test_load_scope_sem_coop_grid_cta_one(device):
17361735

17371736
@triton.jit

third_party/amd/backend/compiler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ class HIPOptions:
5252
default_dot_input_precision: str = "ieee"
5353
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
5454
enable_fp_fusion: bool = True
55-
# TODO: Implement cooperative grid launch for AMD:
56-
# See: https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
5755
launch_cooperative_grid: bool = False
5856
matrix_instr_nonkdim: int = 0
5957
kpack: int = 1

third_party/amd/backend/driver.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def format_of(ty):
225225
}[ty_to_cpp(ty)]
226226

227227
args_format = ''.join([format_of(ty) for ty in signature.values()])
228-
format = "iiiKKOOOO" + args_format
228+
format = "piiiKKOOOO" + args_format
229229
signature = ','.join(map(_serialize_signature, signature.values()))
230230
signature = list(filter(bool, signature.split(',')))
231231
signature = {i: s for i, s in enumerate(signature)}
@@ -267,6 +267,12 @@ def format_of(ty):
267267
unsigned int blockDimY, unsigned int blockDimZ, \\
268268
unsigned int sharedMemBytes, hipStream_t stream, \\
269269
void **kernelParams, void **extra) \\
270+
FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, hipFunction_t f, \\
271+
unsigned int gridDimX, unsigned int gridDimY, \\
272+
unsigned int gridDimZ, unsigned int blockDimX, \\
273+
unsigned int blockDimY, unsigned int blockDimZ, \\
274+
unsigned int sharedMemBytes, hipStream_t stream, \\
275+
void **kernelParams, void **extra) \\
270276
FOR_EACH_ERR_FN(hipPointerGetAttribute, void *data, \\
271277
hipPointer_attribute attribute, hipDeviceptr_t ptr)
272278
@@ -338,14 +344,18 @@ def format_of(ty):
338344
339345
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
340346
341-
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
347+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
342348
// printf("_launch hip kernel\\n");
343349
hipDeviceptr_t global_scratch = 0;
344350
void *params[] = {{ {', '.join(params)} }};
351+
if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
352+
HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
353+
return;
354+
}}
345355
if (gridX*gridY*gridZ > 0) {{
346-
HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
347-
}}
356+
HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
348357
}}
358+
}}
349359
350360
typedef struct _DevicePtrInfo {{
351361
hipDeviceptr_t dev_ptr;
@@ -398,12 +408,14 @@ def format_of(ty):
398408
int gridX, gridY, gridZ;
399409
uint64_t _stream;
400410
uint64_t _function;
411+
int launch_cooperative_grid;
401412
PyObject *launch_enter_hook = NULL;
402413
PyObject *launch_exit_hook = NULL;
403414
PyObject *kernel_metadata = NULL;
404415
PyObject *launch_metadata = NULL;
405416
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
406-
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function,
417+
if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
418+
&gridX, &gridY, &gridZ, &_stream, &_function,
407419
&kernel_metadata, &launch_metadata,
408420
&launch_enter_hook, &launch_exit_hook {args_list})) {{
409421
return NULL;
@@ -426,7 +438,7 @@ def format_of(ty):
426438
427439
// raise exception asap
428440
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
429-
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
441+
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
430442
431443
if(launch_exit_hook != Py_None){{
432444
PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -482,9 +494,10 @@ def __init__(self, src, metadata):
482494
src = make_launcher(constants, signature, metadata.warp_size)
483495
mod = compile_module_from_src(src, "__triton_launcher")
484496
self.launch = mod.launch
497+
self.launch_cooperative_grid = metadata.launch_cooperative_grid
485498

486499
def __call__(self, *args):
487-
self.launch(*args)
500+
self.launch(self.launch_cooperative_grid, *args)
488501

489502

490503
class HIPDriver(GPUDriver):

0 commit comments

Comments
 (0)