
Commit 51245e9

lezcano authored and slawblauciak committed
[BACKEND] Emit reqnctapercluster (#8645)
We now tell LLVM and PTXAS the number of CTAs we are about to use. This allows them to generate better code in most cases. We also clean up all the remaining uses of `cluster_dims` and replace them with queries to `num_ctas`.
1 parent 0bf934b commit 51245e9
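
For context, `num_ctas` is the single user-facing knob for cluster size; with this change the module-level `ttg.num-ctas` attribute is forwarded as `nvvm.cluster_dim`, which is what ends up as the `.reqnctapercluster` PTX directive from the title. Below is a minimal launch sketch; the kernel, shapes, and grid are illustrative only (not part of this commit) and assume a Hopper-class GPU where CTA clusters are supported.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


    src = torch.randn(4096, device="cuda")
    dst = torch.empty_like(src)
    # num_ctas=2 requests clusters of two CTAs; the backend now derives the
    # cluster shape (2, 1, 1) from this single value instead of a separate
    # cluster_dims option.
    copy_kernel[(src.numel() // 1024, )](src, dst, src.numel(), BLOCK=1024, num_ctas=2)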

File tree: 9 files changed, +49 −55 lines


lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 14 additions & 0 deletions
@@ -179,11 +179,25 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
                             "ttg.total-num-warps"))
       numWarps = totalNumWarps.getInt();
 
+    int numCTAs = 1;
+    if (auto module = funcOp->getParentOfType<ModuleOp>()) {
+      if (auto moduleAttr =
+              module->getAttrOfType<IntegerAttr>(triton::gpu::AttrNumCTAsName))
+        numCTAs = moduleAttr.getInt();
+    }
+
     // Set `nvvm.maxnreg` if it was specified on the module.
     if (Attribute maxnregAttr =
             funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName))
       newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);
 
+    // Do we want to do this for nCTAs == 1 whenever sm >= 90?
+    if (numCTAs > 1) {
+      // Request a specific number of CTAs per cluster in the generated PTX.
+      newFuncOp->setAttr(NVVM::NVVMDialect::getClusterDimAttrName(),
+                         rewriter.getDenseI32ArrayAttr(numCTAs));
+    }
+
     // Set an attribute for reqntidx, it could be used in latter LLVM codegen
     // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),

python/triton/compiler/compiler.py

Lines changed: 0 additions & 13 deletions
@@ -294,18 +294,6 @@ def compile(src, target=None, options=None, _env_vars=None):
 
     metadata["cache_dir"] = fn_cache_manager.cache_dir
     metadata["triton_version"] = __version__
-    cluster_dims = getattr(options, "cluster_dims", None)
-    if cluster_dims is None:
-        num_ctas = getattr(options, "num_ctas", None)
-        if num_ctas is None:
-            num_ctas = 1
-        cluster_dims = (num_ctas, 1, 1)
-    if not isinstance(cluster_dims, (list, tuple)):
-        cluster_dims = (cluster_dims, )
-    cluster_dims = tuple(cluster_dims)
-    if len(cluster_dims) < 3:
-        cluster_dims = cluster_dims + (1, ) * (3 - len(cluster_dims))
-    metadata["cluster_dims"] = cluster_dims
     # run compilation pipeline and populate metadata
     stages = dict()
     backend.add_stages(stages, options, src.language)

@@ -432,7 +420,6 @@ def __init__(self, src, metadata_group, hash):
         from collections import namedtuple
         metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
         metadata = json.loads(metadata_path.read_text())
-        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
         # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
         target = metadata['target']
         metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
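
Note that `cluster_dims` is no longer written into the compile metadata at all. A minimal, hypothetical sketch of how a downstream consumer that still wants a three-dimensional cluster shape can derive it from `num_ctas` (the helper name is mine, not part of this change):

    def cluster_dims_from_metadata(metadata: dict) -> tuple:
        # Clusters are only ever requested along x, so the old (x, y, z)
        # triple collapses to (num_ctas, 1, 1).
        return (metadata.get("num_ctas", 1), 1, 1)


    assert cluster_dims_from_metadata({"num_ctas": 4}) == (4, 1, 1)
    assert cluster_dims_from_metadata({}) == (1, 1, 1)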

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 12 additions & 0 deletions
@@ -1,5 +1,17 @@
 // RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv='compute-capability=90 ptx-version=81' --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' | FileCheck %s
 
+module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
+// CHECK-LABEL: @test_cluster_attr
+// CHECK: nvvm.cluster_dim = array<i32: 4>
+// CHECK: nvvm.kernel = 1 : ui1
+// CHECK: nvvm.reqntid = array<i32: 128>
+tt.func @test_cluster_attr(%lb : index, %A : !tt.ptr<f16>) {
+  tt.return
+}
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>

third_party/amd/backend/compiler.py

Lines changed: 0 additions & 4 deletions
@@ -34,7 +34,6 @@ class HIPOptions:
     num_stages: int = 2
     num_ctas: int = 1
     extern_libs: dict = None
-    cluster_dims: tuple = (1, 1, 1)
     debug: bool = False
     sanitize_overflow: bool = True
     arch: str = None

@@ -146,9 +145,6 @@ def pack_metadata(self, metadata):
             metadata.num_warps,
             metadata.num_ctas,
             metadata.shared,
-            metadata.cluster_dims[0],
-            metadata.cluster_dims[1],
-            metadata.cluster_dims[2],
         )
 
     def get_codegen_implementation(self, options):

third_party/amd/backend/driver.py

Lines changed: 4 additions & 4 deletions
@@ -487,7 +487,7 @@ def format_of(ty):
 
 #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   if (gridX * gridY * gridZ == 0)
     return;
   hipDeviceptr_t global_scratch = 0;

@@ -632,8 +632,8 @@ def format_of(ty):
   }}
 
   // extract kernel metadata
-  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
-  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
+  int num_warps, num_ctas, shared_memory;
+  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
     return NULL;
   }}
   // extract launch metadata

@@ -657,7 +657,7 @@ def format_of(ty):
   {newline.join(tensor_desc_decls)}
   {newline.join(ptr_decls)}
   {newline.join(float_storage_decls)}
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
 
   if(launch_exit_hook != Py_None){{
     PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);

third_party/nvidia/backend/compiler.py

Lines changed: 0 additions & 7 deletions
@@ -205,9 +205,6 @@ def pack_metadata(self, metadata):
             metadata.num_warps,
             metadata.num_ctas,
             metadata.shared,
-            metadata.cluster_dims[0],
-            metadata.cluster_dims[1],
-            metadata.cluster_dims[2],
         )
 
     def get_codegen_implementation(self, options):

@@ -317,8 +314,6 @@ def make_ttgir(mod, metadata, opt, capability):
         passes.common.add_canonicalizer(pm)
 
         pm.run(mod, 'make_ttgir')
-        # num_ctas == 16 is non-portable. Does work for H100 and B200 tho
-        metadata["cluster_dims"] = (opt.num_ctas, 1, 1)
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
@@ -338,8 +333,6 @@ def gluon_to_ttgir(self, src, metadata, options, capability):
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
 
         pm.run(mod, 'gluon_to_ttgir')
-        # num_ctas == 16 is non-portable. Does work for H100 and B200 tho
-        metadata["cluster_dims"] = (options.num_ctas, 1, 1)
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod

third_party/nvidia/backend/driver.c

Lines changed: 8 additions & 10 deletions
@@ -232,13 +232,11 @@ defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
                         cuTensorMapEncodeTiled);
 
 static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
-  int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
-      maxActiveClusters = -1;
+  int clusterDim = -1, maxActiveClusters = -1;
   int shared = 0;
   CUfunction func;
 
-  if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX,
-                        &clusterDimY, &clusterDimZ)) {
+  if (!PyArg_ParseTuple(args, "Kii", &func, &shared, &clusterDim)) {
     return NULL;
   }
 
@@ -251,13 +249,13 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
 
   CUlaunchAttribute launchAttr[1];
   launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-  launchAttr[0].value.clusterDim.x = clusterDimX;
-  launchAttr[0].value.clusterDim.y = clusterDimY;
-  launchAttr[0].value.clusterDim.z = clusterDimZ;
+  launchAttr[0].value.clusterDim.x = clusterDim;
+  launchAttr[0].value.clusterDim.y = 1;
+  launchAttr[0].value.clusterDim.z = 1;
   CUlaunchConfig config;
-  config.gridDimX = clusterDimX;
-  config.gridDimY = maxActiveBlocks * clusterDimY;
-  config.gridDimZ = clusterDimZ;
+  config.gridDimX = clusterDim * maxActiveBlocks;
+  config.gridDimY = 1;
+  config.gridDimZ = 1;
   config.blockDimX = 128;
   config.blockDimY = 1;
   config.blockDimZ = 1;

third_party/nvidia/backend/driver.py

Lines changed: 10 additions & 16 deletions
@@ -1,5 +1,4 @@
 import functools
-import operator
 import os
 import subprocess
 import triton

@@ -339,7 +338,7 @@ def format_of(ty):
 }}
 #endif
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0) {{
     // 4 attributes that we can currently pass maximum

@@ -349,16 +348,10 @@ def format_of(ty):
       cuLaunchKernelExHandle = getLaunchKernelExHandle();
     }}
     CUlaunchConfig config;
-    config.gridDimX = gridX;
+    config.gridDimX = gridX * num_ctas;
     config.gridDimY = gridY;
     config.gridDimZ = gridZ;
 
-    if (num_ctas != 1) {{
-      config.gridDimX *= clusterDimX;
-      config.gridDimY *= clusterDimY;
-      config.gridDimZ *= clusterDimZ;
-    }}
-
     config.blockDimX = 32 * num_warps;
     config.blockDimY = 1;
     config.blockDimZ = 1;

@@ -382,9 +375,9 @@ def format_of(ty):
     if (num_ctas != 1) {{
       CUlaunchAttribute clusterAttr = {{}};
       clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-      clusterAttr.value.clusterDim.x = clusterDimX;
-      clusterAttr.value.clusterDim.y = clusterDimY;
-      clusterAttr.value.clusterDim.z = clusterDimZ;
+      clusterAttr.value.clusterDim.x = num_ctas;
+      clusterAttr.value.clusterDim.y = 1;
+      clusterAttr.value.clusterDim.z = 1;
       launchAttr[num_attrs] = clusterAttr;
       ++num_attrs;

@@ -395,6 +388,7 @@ def format_of(ty):
       ++num_attrs;
     }}
 
+    // num_ctas == 16 is non-portable. Does work for H100 and B200 tho
    config.numAttrs = num_attrs;
    if (num_ctas == 16) {{
      CUDA_CHECK(cuFuncSetAttribute(

@@ -540,8 +534,8 @@ def format_of(ty):
    return NULL;
  }}
 
-  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
-  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
+  int num_warps, num_ctas, shared_memory;
+  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
    return NULL;
  }}

@@ -577,7 +571,7 @@ def format_of(ty):
  {newline.join(tma_decls)}
  {newline.join(float_storage_decls)}
  Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {{
    return NULL;

@@ -719,7 +713,7 @@ def __init__(self, src, metadata):
             libraries=libraries,
         )
 
-        self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
+        self.num_ctas = getattr(metadata, "num_ctas", 1)
         self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
         self.global_scratch_size = metadata.global_scratch_size
         self.global_scratch_align = metadata.global_scratch_align
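
To make the launcher's grid arithmetic concrete: `gridDimX` is now unconditionally scaled by `num_ctas`, and the cluster attribute is fixed to `(num_ctas, 1, 1)`. A small sketch of the same computation in plain Python (the values are illustrative, not from this commit):

    num_ctas = 2            # CTAs per cluster, taken from the kernel metadata
    grid = (128, 4, 1)      # grid requested by the caller

    # Mirrors the generated C launcher: scale x by num_ctas, cluster only along x.
    launch_grid = (grid[0] * num_ctas, grid[1], grid[2])
    cluster_dim = (num_ctas, 1, 1)

    assert launch_grid == (256, 4, 1)
    assert launch_grid[0] % cluster_dim[0] == 0  # grid.x stays a multiple of the cluster size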

third_party/proton/tutorials/matmul.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def metadata_fn(
     grid_x, grid_y, grid_z = unpack_grid(grid)
     num_warps = metadata.num_warps
     num_stages = metadata.num_stages
-    cluster_x, cluster_y, cluster_z = metadata.cluster_dims
+    cluster_x, cluster_y, cluster_z = unpack_grid((metadata.num_ctas, ))
     shared_memory = metadata.shared
     M, K = args["a_ptr"].shape
     K, N = args["b_ptr"].shape
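
Assuming proton's `unpack_grid` pads short tuples with ones to three dimensions (as it already does for the launch grid a few lines above), the new call reproduces the old `(x, y, z)` cluster triple. A tiny stand-in to illustrate the expected behavior (the helper below is hypothetical, not proton's actual implementation):

    def unpack_grid_standin(grid):
        # Hypothetical stand-in for proton's unpack_grid: pad to 3 dims with 1s.
        return tuple(grid) + (1, ) * (3 - len(grid))


    num_ctas = 4
    assert unpack_grid_standin((num_ctas, )) == (4, 1, 1)  # matches the old cluster_dims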
