
Commit 9b75018

[BACKEND] Emit reqnctapercluster (#8645)
We now tell LLVM and PTXAS the number of CTAs we are about to use. This allows them to generate better code in most cases. We also clean up all remaining uses of `cluster_dims` and replace them with queries to `num_ctas`.
1 parent 0be6469 commit 9b75018
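
For context, a minimal sketch of how this change is exercised from the Python side. The kernel, grid, and assertion below are illustrative and not part of this commit; it assumes a Hopper-class GPU (num_ctas > 1 requires sm_90+) and that launching a jitted kernel returns the compiled handle with its .asm dict, as in the Triton tutorials.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


src = torch.randn(4096, device="cuda")
dst = torch.empty_like(src)
# Launch with a 2-CTA cluster; the backend now forwards this to LLVM/PTXAS.
handle = copy_kernel[(src.numel() // 128,)](src, dst, src.numel(), BLOCK=128, num_ctas=2)
# With this commit, the generated PTX should carry a .reqnctapercluster directive.
assert "reqnctapercluster" in handle.asm["ptx"]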

File tree

9 files changed: +49 -55 lines changed

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 14 additions & 0 deletions
@@ -179,11 +179,25 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
                 "ttg.total-num-warps"))
       numWarps = totalNumWarps.getInt();

+    int numCTAs = 1;
+    if (auto module = funcOp->getParentOfType<ModuleOp>()) {
+      if (auto moduleAttr =
+              module->getAttrOfType<IntegerAttr>(triton::gpu::AttrNumCTAsName))
+        numCTAs = moduleAttr.getInt();
+    }
+
     // Set `nvvm.maxnreg` if it was specified on the module.
     if (Attribute maxnregAttr =
             funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName))
       newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);

+    // Do we want to do this for nCTAs == 1 whenever sm >= 90?
+    if (numCTAs > 1) {
+      // Request a specific number of CTAs per cluster in the generated PTX.
+      newFuncOp->setAttr(NVVM::NVVMDialect::getClusterDimAttrName(),
+                         rewriter.getDenseI32ArrayAttr(numCTAs));
+    }
+
     // Set an attribute for reqntidx, it could be used in latter LLVM codegen
     // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),

python/triton/compiler/compiler.py

Lines changed: 0 additions & 13 deletions
@@ -283,18 +283,6 @@ def compile(src, target=None, options=None, _env_vars=None):
         **env_vars,
     }
     metadata["triton_version"] = __version__
-    cluster_dims = getattr(options, "cluster_dims", None)
-    if cluster_dims is None:
-        num_ctas = getattr(options, "num_ctas", None)
-        if num_ctas is None:
-            num_ctas = 1
-        cluster_dims = (num_ctas, 1, 1)
-    if not isinstance(cluster_dims, (list, tuple)):
-        cluster_dims = (cluster_dims, )
-    cluster_dims = tuple(cluster_dims)
-    if len(cluster_dims) < 3:
-        cluster_dims = cluster_dims + (1, ) * (3 - len(cluster_dims))
-    metadata["cluster_dims"] = cluster_dims
     # run compilation pipeline and populate metadata
     stages = dict()
     backend.add_stages(stages, options, src.language)
@@ -419,7 +407,6 @@ def __init__(self, src, metadata_group, hash):
         from collections import namedtuple
         metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
         metadata = json.loads(metadata_path.read_text())
-        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
         # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
         target = metadata['target']
         metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 12 additions & 0 deletions
@@ -1,5 +1,17 @@
 // RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv='compute-capability=90 ptx-version=81' --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' | FileCheck %s

+module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
+// CHECK-LABEL: @test_cluster_attr
+// CHECK: nvvm.cluster_dim = array<i32: 4>
+// CHECK: nvvm.kernel = 1 : ui1
+// CHECK: nvvm.reqntid = array<i32: 128>
+tt.func @test_cluster_attr(%lb : index, %A : !tt.ptr<f16>) {
+  tt.return
+}
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>

third_party/amd/backend/compiler.py

Lines changed: 0 additions & 4 deletions
@@ -34,7 +34,6 @@ class HIPOptions:
     num_stages: int = 2
     num_ctas: int = 1
     extern_libs: dict = None
-    cluster_dims: tuple = (1, 1, 1)
     debug: bool = False
     sanitize_overflow: bool = True
     arch: str = None
@@ -138,9 +137,6 @@ def pack_metadata(self, metadata):
             metadata.num_warps,
             metadata.num_ctas,
             metadata.shared,
-            metadata.cluster_dims[0],
-            metadata.cluster_dims[1],
-            metadata.cluster_dims[2],
         )

     def get_codegen_implementation(self, options):

third_party/amd/backend/driver.py

Lines changed: 4 additions & 4 deletions
@@ -440,7 +440,7 @@ def format_of(ty):

 #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   hipDeviceptr_t global_scratch = 0;
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
@@ -548,8 +548,8 @@ def format_of(ty):
   {' '.join(float_storage_decls)}

   // extract kernel metadata
-  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
-  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
+  int num_warps, num_ctas, shared_memory;
+  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
     return NULL;
   }}
   // extract launch metadata
@@ -571,7 +571,7 @@ def format_of(ty):

   // raise exception asap
   {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});

   if(launch_exit_hook != Py_None){{
     PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);

third_party/nvidia/backend/compiler.py

Lines changed: 0 additions & 7 deletions
@@ -204,9 +204,6 @@ def pack_metadata(self, metadata):
             metadata.num_warps,
             metadata.num_ctas,
             metadata.shared,
-            metadata.cluster_dims[0],
-            metadata.cluster_dims[1],
-            metadata.cluster_dims[2],
         )

     def get_codegen_implementation(self, options):
@@ -316,8 +313,6 @@ def make_ttgir(mod, metadata, opt, capability):
         passes.common.add_canonicalizer(pm)

         pm.run(mod, 'make_ttgir')
-        # num_ctas == 16 is non-portable. Does work for H100 and B200 tho
-        metadata["cluster_dims"] = (opt.num_ctas, 1, 1)
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod

@@ -336,8 +331,6 @@ def gluon_to_ttgir(self, src, metadata, options, capability):
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)

         pm.run(mod, 'gluon_to_ttgir')
-        # num_ctas == 16 is non-portable. Does work for H100 and B200 tho
-        metadata["cluster_dims"] = (options.num_ctas, 1, 1)
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
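
Downstream consumers that still want a three-component cluster shape can derive it from `num_ctas`; a minimal sketch (the helper name is ours, not from this commit), reflecting the `(num_ctas, 1, 1)` mapping the removed lines used:

def cluster_dims_from_num_ctas(num_ctas: int) -> tuple:
    # All CTAs of the cluster are placed along x, matching the removed metadata["cluster_dims"].
    return (num_ctas, 1, 1)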

third_party/nvidia/backend/driver.c

Lines changed: 8 additions & 10 deletions
@@ -203,13 +203,11 @@ defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
                         cuTensorMapEncodeTiled);

 static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
-  int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
-      maxActiveClusters = -1;
+  int clusterDim = -1, maxActiveClusters = -1;
   int shared = 0;
   CUfunction func;

-  if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX,
-                        &clusterDimY, &clusterDimZ)) {
+  if (!PyArg_ParseTuple(args, "Kii", &func, &shared, &clusterDim)) {
     return NULL;
   }

@@ -222,13 +220,13 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {

   CUlaunchAttribute launchAttr[1];
   launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-  launchAttr[0].value.clusterDim.x = clusterDimX;
-  launchAttr[0].value.clusterDim.y = clusterDimY;
-  launchAttr[0].value.clusterDim.z = clusterDimZ;
+  launchAttr[0].value.clusterDim.x = clusterDim;
+  launchAttr[0].value.clusterDim.y = 1;
+  launchAttr[0].value.clusterDim.z = 1;
   CUlaunchConfig config;
-  config.gridDimX = clusterDimX;
-  config.gridDimY = maxActiveBlocks * clusterDimY;
-  config.gridDimZ = clusterDimZ;
+  config.gridDimX = clusterDim * maxActiveBlocks;
+  config.gridDimY = 1;
+  config.gridDimZ = 1;
   config.blockDimX = 128;
   config.blockDimY = 1;
   config.blockDimZ = 1;

third_party/nvidia/backend/driver.py

Lines changed: 10 additions & 16 deletions
@@ -1,5 +1,4 @@
 import functools
-import operator
 import os
 import subprocess
 import triton
@@ -314,7 +313,7 @@ def format_of(ty):
    return cuLaunchKernelExHandle;
}}

-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
  void *params[] = {{ {', '.join(params)} }};
  if (gridX*gridY*gridZ > 0) {{
    // 4 attributes that we can currently pass maximum
@@ -324,16 +323,10 @@ def format_of(ty):
      cuLaunchKernelExHandle = getLaunchKernelExHandle();
    }}
    CUlaunchConfig config;
-    config.gridDimX = gridX;
+    config.gridDimX = gridX * num_ctas;
    config.gridDimY = gridY;
    config.gridDimZ = gridZ;

-    if (num_ctas != 1) {{
-      config.gridDimX *= clusterDimX;
-      config.gridDimY *= clusterDimY;
-      config.gridDimZ *= clusterDimZ;
-    }}
-
    config.blockDimX = 32 * num_warps;
    config.blockDimY = 1;
    config.blockDimZ = 1;
@@ -357,9 +350,9 @@ def format_of(ty):
    if (num_ctas != 1) {{
      CUlaunchAttribute clusterAttr = {{}};
      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-      clusterAttr.value.clusterDim.x = clusterDimX;
-      clusterAttr.value.clusterDim.y = clusterDimY;
-      clusterAttr.value.clusterDim.z = clusterDimZ;
+      clusterAttr.value.clusterDim.x = num_ctas;
+      clusterAttr.value.clusterDim.y = 1;
+      clusterAttr.value.clusterDim.z = 1;
      launchAttr[num_attrs] = clusterAttr;
      ++num_attrs;
@@ -370,6 +363,7 @@ def format_of(ty):
      ++num_attrs;
    }}

+    // num_ctas == 16 is non-portable. Does work for H100 and B200 tho
    config.numAttrs = num_attrs;
    if (num_ctas == 16) {{
      CUDA_CHECK(cuFuncSetAttribute(
@@ -515,8 +509,8 @@ def format_of(ty):
    return NULL;
  }}

-  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
-  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
+  int num_warps, num_ctas, shared_memory;
+  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
    return NULL;
  }}
@@ -552,7 +546,7 @@ def format_of(ty):
  {newline.join(tma_decls)}
  {newline.join(float_storage_decls)}
  Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {{
    return NULL;
@@ -694,7 +688,7 @@ def __init__(self, src, metadata):
            libraries=libraries,
        )

-        self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
+        self.num_ctas = getattr(metadata, "num_ctas", 1)
        self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
        self.global_scratch_size = metadata.global_scratch_size
        self.global_scratch_align = metadata.global_scratch_align
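
The updated launcher arithmetic, restated as a hedged Python sketch (names ours, not from this commit): the Triton grid now counts clusters along x, so the CUDA grid is scaled by num_ctas and each group of num_ctas consecutive CTAs forms one cluster.

def cuda_launch_shape(grid_x, grid_y, grid_z, num_ctas):
    # Mirrors the updated _launch: gridDimX is multiplied by num_ctas and the
    # cluster attribute only ever spans the x dimension now.
    grid = (grid_x * num_ctas, grid_y, grid_z)
    cluster = (num_ctas, 1, 1) if num_ctas != 1 else None
    return grid, cluster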

third_party/proton/tutorials/matmul.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def metadata_fn(
     grid_x, grid_y, grid_z = unpack_grid(grid)
     num_warps = metadata.num_warps
     num_stages = metadata.num_stages
-    cluster_x, cluster_y, cluster_z = metadata.cluster_dims
+    cluster_x, cluster_y, cluster_z = unpack_grid((metadata.num_ctas, ))
     shared_memory = metadata.shared
     M, K = args["a_ptr"].shape
     K, N = args["b_ptr"].shape
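
The tutorial change relies on unpack_grid padding a one-element tuple out to three dimensions; a sketch of that assumed behaviour (illustrative, not copied from Proton):

def unpack_grid_sketch(grid):
    # Assumed behaviour: pad a 1-, 2- or 3-element grid with 1s and return (x, y, z).
    grid = tuple(grid) + (1,) * (3 - len(grid))
    return grid[0], grid[1], grid[2]

# With this commit, cluster_x, cluster_y, cluster_z = unpack_grid((metadata.num_ctas, ))
# yields (num_ctas, 1, 1), the same shape the removed metadata.cluster_dims carried.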
