Commit 3c058ee

[NVIDIA][Launcher] NV Cooperative Grid Launching (CU_LAUNCH_ATTRIBUTE_COOPERATIVE) (#5381)
This change sets the CU_LAUNCH_ATTRIBUTE_COOPERATIVE launch attribute before calling `cuLaunchKernelEx`. It is intended to pair with the load/store atomics from triton-lang/triton#5187 and to provide grid synchronization similar to what CUDA cooperative groups do (a usage sketch follows the checklist below). @ptillet Any recommendations on the UI for using this in code would be most welcome :-)

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/python/test` for end-to-end tests
  - [?] This PR does not need a test because: I am not entirely sure yet how to assert that one driver API attribute is used rather than another for this case. I did add a test that exercises the `launch_cooperative_grid=True` launch flag, but it does not confirm that the plumbing actually triggers the attribute; I confirmed that offline with an assert.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
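For a quick sense of the user-facing surface, the flag is passed per launch as a keyword argument on the kernel grid call and defaults to `False`. A minimal sketch, modeled on the added test; the kernel, tensor, and size names here are illustrative and not part of this PR:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(ptr, numel, BLOCK_SIZE: tl.constexpr):
    # Trivial kernel body; the interesting part is the launch flag below.
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel
    x = tl.load(ptr + offsets, mask=mask)
    tl.store(ptr + offsets, x, mask=mask)


data = torch.zeros(256, device="cuda", dtype=torch.float32)
# launch_cooperative_grid=True asks the NVIDIA launcher to set
# CU_LAUNCH_ATTRIBUTE_COOPERATIVE and route the launch through cuLaunchKernelEx.
copy_kernel[(2, )](data, data.numel(), BLOCK_SIZE=128, launch_cooperative_grid=True)
```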
1 parent 43f1ad4 commit 3c058ee

File tree

5 files changed: +89, -9 lines

python/test/unit/language/test_core.py

Lines changed: 42 additions & 0 deletions

```diff
@@ -1646,6 +1646,48 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
     assert (torch.equal(X, Y))
 
 
+@pytest.mark.interpreter
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9 or is_hip(),
+                    reason="Requires compute capability >= 9 for NV")
+def test_load_scope_sem_coop_grid_cta_not_one(device):
+
+    @triton.jit
+    def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
+        numel = 512
+        offset = tl.program_id(0) * BLOCK_SIZE
+        index = offset
+        mask = index < numel
+        a = tl.load(ptrs, mask=mask)
+        tl.store(ptrs, a)
+
+    block_size = 128
+    data = torch.zeros((128, ), device=device, dtype=torch.float32)
+
+    out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=True)
+    out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=False)
+
+
+@pytest.mark.interpreter
+@pytest.mark.skipif(is_hip(), reason="Not implemented for AMD At this moment")
+def test_load_scope_sem_coop_grid_cta_one(device):
+
+    @triton.jit
+    def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
+        numel = 512
+        offset = tl.program_id(0) * BLOCK_SIZE
+        index = offset
+        mask = index < numel
+        a = tl.load(ptrs, mask=mask)
+        tl.store(ptrs, a)
+
+    block_size = 128
+    data = torch.zeros((128, ), device=device, dtype=torch.float32)
+
+    # Should do nothing different for num_ctas=1 (with coop launch grid)
+    out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=True)
+    out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=False)
+
+
 # ---------------
 # test cast
 # ---------------
```
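The tests above only exercise the launch plumbing. The motivation stated in the commit message is grid-wide synchronization built on the load/store atomics of triton-lang/triton#5187: a cooperative launch guarantees that all programs in the grid are resident concurrently, which is what makes a spin-wait across programs safe. The following is a hypothetical sketch of that kind of barrier, not code from this PR; it assumes the `sem`/`scope` arguments of `tl.atomic_add` and a zero-initialized int32 counter:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def grid_barrier_kernel(counter_ptr, NUM_PROGRAMS: tl.constexpr):
    # Announce that this program has reached the barrier.
    arrived = tl.atomic_add(counter_ptr, 1, sem="release", scope="gpu") + 1
    # Spin until every program has arrived. Without a cooperative launch some
    # programs might not be resident yet and this loop could never terminate.
    while arrived < NUM_PROGRAMS:
        arrived = tl.atomic_add(counter_ptr, 0, sem="acquire", scope="gpu")


counter = torch.zeros(1, device="cuda", dtype=torch.int32)
n_programs = 4
grid_barrier_kernel[(n_programs, )](counter, NUM_PROGRAMS=n_programs,
                                    launch_cooperative_grid=True)
```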

python/triton/runtime/jit.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -504,7 +504,7 @@ def _call_hook(
         name = self.fn.__name__
         module = self.fn.__module__
         arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
-        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})"
+        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
 
         class JitFunctionInfo:
 
@@ -524,6 +524,7 @@ def __init__(self, module, name, jit_function):
             'num_ctas': options.num_ctas,
             'num_stages': options.num_stages,
             'enable_fp_fusion': options.enable_fp_fusion,
+            'launch_cooperative_grid': options.launch_cooperative_grid,
             'extern_libs': options.extern_libs,
             'configs': configs,
             'specialization_data': specialization_data,
```

third_party/amd/backend/compiler.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -56,6 +56,9 @@ class HIPOptions:
     default_dot_input_precision: str = "ieee"
     allowed_dot_input_precisions: Tuple[str] = ("ieee", )
     enable_fp_fusion: bool = True
+    # TODO: Implement cooperative grid launch for AMD:
+    # See: https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
+    launch_cooperative_grid: bool = False
     matrix_instr_nonkdim: int = 0
     kpack: int = 1
     allow_flush_denorm: bool = False
```

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -112,6 +112,7 @@ class CUDAOptions:
     cluster_dims: tuple = (1, 1, 1)
     ptx_version: int = None
     enable_fp_fusion: bool = True
+    launch_cooperative_grid: bool = False
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
     deprecated_fp8_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
```

third_party/nvidia/backend/driver.py

Lines changed: 41 additions & 8 deletions

```diff
@@ -159,7 +159,7 @@ def format_of(ty):
 
     signature = {k: v for k, v in signature.items() if v != 'constexpr'}
     args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
-    format = "iiiKKOOOOO" + args_format
+    format = "iiiKKpOOOOO" + args_format
     signature = ','.join(signature.values()).replace('[', '').replace(']', '')
     signature = list(filter(bool, signature.split(',')))
     signature = {i: s for i, s in enumerate(signature)}
@@ -227,19 +227,50 @@ def format_of(ty):
   return cuLaunchKernelExHandle;
 }}
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0) {{
-    if (num_ctas == 1) {{
+    if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{
       CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
+    }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{
+      CUlaunchAttribute launchAttr[1];
+      CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
+      launchAttr[0] = coopAttr;
+
+      CUlaunchConfig config;
+      config.gridDimX = gridX;
+      config.gridDimY = gridY;
+      config.gridDimZ = gridZ;
+      config.blockDimX = 32 * num_warps;
+      config.blockDimY = 1;
+      config.blockDimZ = 1;
+      config.sharedMemBytes = shared_memory;
+      config.hStream = stream;
+      config.attrs = launchAttr;
+      config.numAttrs = 1;
+
+      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
+      if (cuLaunchKernelExHandle == NULL) {{
+        cuLaunchKernelExHandle = getLaunchKernelExHandle();
+      }}
+      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
+
     }} else {{
-      CUlaunchAttribute launchAttr[2];
+      CUlaunchAttribute launchAttr[3];
       launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
       launchAttr[0].value.clusterDim.x = clusterDimX;
       launchAttr[0].value.clusterDim.y = clusterDimY;
       launchAttr[0].value.clusterDim.z = clusterDimZ;
       launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
       launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+
+      unsigned numAttrs = 2;
+      if (0 != launch_cooperative_grid) {{
+        CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
+        launchAttr[2] = coopAttr;
+        numAttrs = 3;
+      }}
+
       CUlaunchConfig config;
       config.gridDimX = gridX * clusterDimX;
       config.gridDimY = gridY * clusterDimY;
@@ -250,7 +281,7 @@ def format_of(ty):
       config.sharedMemBytes = shared_memory;
       config.hStream = stream;
       config.attrs = launchAttr;
-      config.numAttrs = 2;
+      config.numAttrs = numAttrs;
       static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
       if (cuLaunchKernelExHandle == NULL) {{
         cuLaunchKernelExHandle = getLaunchKernelExHandle();
@@ -375,14 +406,15 @@ def format_of(ty):
   int gridX, gridY, gridZ;
   uint64_t _stream;
   uint64_t _function;
+  int launch_cooperative_grid;
   PyObject *launch_enter_hook = NULL;
   PyObject *launch_exit_hook = NULL;
   PyObject *kernel_metadata = NULL;
   PyObject *launch_metadata = NULL;
   PyObject *global_scratch_obj = NULL;
   {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
   if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
-                       &_stream, &_function, &global_scratch_obj,
+                       &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
                        &kernel_metadata, &launch_metadata,
                        &launch_enter_hook, &launch_exit_hook{args_list})) {{
     return NULL;
@@ -416,7 +448,7 @@ def format_of(ty):
   {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
   {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])};
   Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
   Py_END_ALLOW_THREADS;
   if (PyErr_Occurred()) {{
     return NULL;
@@ -471,6 +503,7 @@ def __init__(self, src, metadata):
         self.launch = mod.launch
         self.global_scratch_size = metadata.global_scratch_size
         self.global_scratch_align = metadata.global_scratch_align
+        self.launch_cooperative_grid = metadata.launch_cooperative_grid
 
     def __call__(self, gridX, gridY, gridZ, stream, function, *args):
         if self.global_scratch_size > 0:
@@ -479,7 +512,7 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args):
             global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
         else:
             global_scratch = None
-        self.launch(gridX, gridY, gridZ, stream, function, global_scratch, *args)
+        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
 
 
 class CudaDriver(GPUDriver):
```
