Commit bccf48a

Revert "[BACKEND] Emit reqnctapercluster (#8645)"
This reverts commit 9b75018.
1 parent 98d1896 commit bccf48a

File tree

9 files changed: +55 −49 lines changed

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 0 additions & 14 deletions
@@ -179,25 +179,11 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
                 "ttg.total-num-warps"))
       numWarps = totalNumWarps.getInt();
 
-    int numCTAs = 1;
-    if (auto module = funcOp->getParentOfType<ModuleOp>()) {
-      if (auto moduleAttr =
-              module->getAttrOfType<IntegerAttr>(triton::gpu::AttrNumCTAsName))
-        numCTAs = moduleAttr.getInt();
-    }
-
     // Set `nvvm.maxnreg` if it was specified on the module.
     if (Attribute maxnregAttr =
             funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName))
       newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);
 
-    // Do we want to do this for nCTAs == 1 whenever sm >= 90?
-    if (numCTAs > 1) {
-      // Request a specific number of CTAs per cluster in the generated PTX.
-      newFuncOp->setAttr(NVVM::NVVMDialect::getClusterDimAttrName(),
-                         rewriter.getDenseI32ArrayAttr(numCTAs));
-    }
-
     // Set an attribute for reqntidx, it could be used in latter LLVM codegen
     // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),

python/triton/compiler/compiler.py

Lines changed: 13 additions & 0 deletions
@@ -294,6 +294,18 @@ def compile(src, target=None, options=None, _env_vars=None):
 
     metadata["cache_dir"] = fn_cache_manager.cache_dir
     metadata["triton_version"] = __version__
+    cluster_dims = getattr(options, "cluster_dims", None)
+    if cluster_dims is None:
+        num_ctas = getattr(options, "num_ctas", None)
+        if num_ctas is None:
+            num_ctas = 1
+        cluster_dims = (num_ctas, 1, 1)
+    if not isinstance(cluster_dims, (list, tuple)):
+        cluster_dims = (cluster_dims, )
+    cluster_dims = tuple(cluster_dims)
+    if len(cluster_dims) < 3:
+        cluster_dims = cluster_dims + (1, ) * (3 - len(cluster_dims))
+    metadata["cluster_dims"] = cluster_dims
     # run compilation pipeline and populate metadata
     stages = dict()
     backend.add_stages(stages, options, src.language)
@@ -420,6 +432,7 @@ def __init__(self, src, metadata_group, hash):
         from collections import namedtuple
         metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
         metadata = json.loads(metadata_path.read_text())
+        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
         # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
         target = metadata['target']
         metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
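For reference, a minimal standalone sketch (not Triton source; names and values are illustrative) of the normalization the compile() hunk above restores: whatever the options report ends up as a 3-tuple under metadata["cluster_dims"].

def normalize_cluster_dims(cluster_dims=None, num_ctas=None):
    # Mirrors the logic added to compile() above: default from num_ctas,
    # coerce scalars to a tuple, and pad to exactly three dimensions.
    if cluster_dims is None:
        cluster_dims = (num_ctas if num_ctas is not None else 1, 1, 1)
    if not isinstance(cluster_dims, (list, tuple)):
        cluster_dims = (cluster_dims, )
    cluster_dims = tuple(cluster_dims)
    if len(cluster_dims) < 3:
        cluster_dims = cluster_dims + (1, ) * (3 - len(cluster_dims))
    return cluster_dims

assert normalize_cluster_dims() == (1, 1, 1)
assert normalize_cluster_dims(num_ctas=2) == (2, 1, 1)
assert normalize_cluster_dims(cluster_dims=4) == (4, 1, 1)
assert normalize_cluster_dims(cluster_dims=[2, 2]) == (2, 2, 1)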

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 0 additions & 12 deletions
@@ -1,17 +1,5 @@
 // RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv='compute-capability=90 ptx-version=81' --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' | FileCheck %s
 
-module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
-  // CHECK-LABEL: @test_cluster_attr
-  // CHECK: nvvm.cluster_dim = array<i32: 4>
-  // CHECK: nvvm.kernel = 1 : ui1
-  // CHECK: nvvm.reqntid = array<i32: 128>
-  tt.func @test_cluster_attr(%lb : index, %A : !tt.ptr<f16>) {
-    tt.return
-  }
-}
-
-// -----
-
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>

third_party/amd/backend/compiler.py

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,7 @@ class HIPOptions:
     num_stages: int = 2
     num_ctas: int = 1
     extern_libs: dict = None
+    cluster_dims: tuple = (1, 1, 1)
     debug: bool = False
     sanitize_overflow: bool = True
     arch: str = None
@@ -137,6 +138,9 @@ def pack_metadata(self, metadata):
             metadata.num_warps,
             metadata.num_ctas,
             metadata.shared,
+            metadata.cluster_dims[0],
+            metadata.cluster_dims[1],
+            metadata.cluster_dims[2],
         )
 
     def get_codegen_implementation(self, options):

third_party/amd/backend/driver.py

Lines changed: 4 additions & 4 deletions
@@ -440,7 +440,7 @@ def format_of(ty):
 
 #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   hipDeviceptr_t global_scratch = 0;
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
@@ -548,8 +548,8 @@ def format_of(ty):
   {' '.join(float_storage_decls)}
 
   // extract kernel metadata
-  int num_warps, num_ctas, shared_memory;
-  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
+  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
+  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
     return NULL;
   }}
   // extract launch metadata
@@ -571,7 +571,7 @@ def format_of(ty):
 
   // raise exception asap
   {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
 
   if(launch_exit_hook != Py_None){{
     PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);

third_party/nvidia/backend/compiler.py

Lines changed: 7 additions & 0 deletions
@@ -204,6 +204,9 @@ def pack_metadata(self, metadata):
             metadata.num_warps,
             metadata.num_ctas,
             metadata.shared,
+            metadata.cluster_dims[0],
+            metadata.cluster_dims[1],
+            metadata.cluster_dims[2],
         )
 
     def get_codegen_implementation(self, options):
@@ -313,6 +316,8 @@ def make_ttgir(mod, metadata, opt, capability):
         passes.common.add_canonicalizer(pm)
 
         pm.run(mod, 'make_ttgir')
+        # num_ctas == 16 is non-portable. Does work for H100 and B200 tho
+        metadata["cluster_dims"] = (opt.num_ctas, 1, 1)
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
@@ -331,6 +336,8 @@ def gluon_to_ttgir(self, src, metadata, options, capability):
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
 
         pm.run(mod, 'gluon_to_ttgir')
+        # num_ctas == 16 is non-portable. Does work for H100 and B200 tho
+        metadata["cluster_dims"] = (options.num_ctas, 1, 1)
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
third_party/nvidia/backend/driver.c

Lines changed: 10 additions & 8 deletions
@@ -231,11 +231,13 @@ defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
                         cuTensorMapEncodeTiled);
 
 static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
-  int clusterDim = -1, maxActiveClusters = -1;
+  int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
+      maxActiveClusters = -1;
   int shared = 0;
   CUfunction func;
 
-  if (!PyArg_ParseTuple(args, "Kii", &func, &shared, &clusterDim)) {
+  if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX,
+                        &clusterDimY, &clusterDimZ)) {
     return NULL;
   }
 
@@ -248,13 +250,13 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
 
   CUlaunchAttribute launchAttr[1];
   launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-  launchAttr[0].value.clusterDim.x = clusterDim;
-  launchAttr[0].value.clusterDim.y = 1;
-  launchAttr[0].value.clusterDim.z = 1;
+  launchAttr[0].value.clusterDim.x = clusterDimX;
+  launchAttr[0].value.clusterDim.y = clusterDimY;
+  launchAttr[0].value.clusterDim.z = clusterDimZ;
   CUlaunchConfig config;
-  config.gridDimX = clusterDim * maxActiveBlocks;
-  config.gridDimY = 1;
-  config.gridDimZ = 1;
+  config.gridDimX = clusterDimX;
+  config.gridDimY = maxActiveBlocks * clusterDimY;
+  config.gridDimZ = clusterDimZ;
   config.blockDimX = 128;
   config.blockDimY = 1;
   config.blockDimZ = 1;
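As a quick cross-check of the restored occupancyMaxActiveClusters logic, a hedged Python sketch (illustrative only, not Triton source) of how the trial grid for the occupancy query is derived from the full cluster shape:

def occupancy_probe_grid(cluster_dims, max_active_blocks):
    # Mirrors the restored driver.c above: gridDimX = clusterDimX,
    # gridDimY = maxActiveBlocks * clusterDimY, gridDimZ = clusterDimZ.
    cx, cy, cz = cluster_dims
    return (cx, max_active_blocks * cy, cz)

# e.g. a (2, 1, 1) cluster with 4 resident blocks probes a (2, 4, 1) grid
assert occupancy_probe_grid((2, 1, 1), max_active_blocks=4) == (2, 4, 1)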

third_party/nvidia/backend/driver.py

Lines changed: 16 additions & 10 deletions
@@ -1,4 +1,5 @@
 import functools
+import operator
 import os
 import subprocess
 import triton
@@ -338,7 +339,7 @@ def format_of(ty):
 }}
 #endif
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0) {{
     // 4 attributes that we can currently pass maximum
@@ -348,10 +349,16 @@ def format_of(ty):
       cuLaunchKernelExHandle = getLaunchKernelExHandle();
     }}
     CUlaunchConfig config;
-    config.gridDimX = gridX * num_ctas;
+    config.gridDimX = gridX;
     config.gridDimY = gridY;
     config.gridDimZ = gridZ;
 
+    if (num_ctas != 1) {{
+      config.gridDimX *= clusterDimX;
+      config.gridDimY *= clusterDimY;
+      config.gridDimZ *= clusterDimZ;
+    }}
+
     config.blockDimX = 32 * num_warps;
     config.blockDimY = 1;
     config.blockDimZ = 1;
@@ -375,9 +382,9 @@ def format_of(ty):
     if (num_ctas != 1) {{
       CUlaunchAttribute clusterAttr = {{}};
       clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-      clusterAttr.value.clusterDim.x = num_ctas;
-      clusterAttr.value.clusterDim.y = 1;
-      clusterAttr.value.clusterDim.z = 1;
+      clusterAttr.value.clusterDim.x = clusterDimX;
+      clusterAttr.value.clusterDim.y = clusterDimY;
+      clusterAttr.value.clusterDim.z = clusterDimZ;
       launchAttr[num_attrs] = clusterAttr;
       ++num_attrs;
     }}
@@ -388,7 +395,6 @@ def format_of(ty):
       ++num_attrs;
     }}
 
-    // num_ctas == 16 is non-portable. Does work for H100 and B200 tho
     config.numAttrs = num_attrs;
     if (num_ctas == 16) {{
       CUDA_CHECK(cuFuncSetAttribute(
@@ -534,8 +540,8 @@ def format_of(ty):
     return NULL;
  }}
 
-  int num_warps, num_ctas, shared_memory;
-  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
+  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
+  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
     PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
     return NULL;
   }}
@@ -571,7 +577,7 @@ def format_of(ty):
   {newline.join(tma_decls)}
   {newline.join(float_storage_decls)}
   Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
   Py_END_ALLOW_THREADS;
   if (PyErr_Occurred()) {{
     return NULL;
@@ -713,7 +719,7 @@ def __init__(self, src, metadata):
             libraries=libraries,
         )
 
-        self.num_ctas = getattr(metadata, "num_ctas", 1)
+        self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
         self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
         self.global_scratch_size = metadata.global_scratch_size
         self.global_scratch_align = metadata.global_scratch_align
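The six integers that pack_metadata() now emits (in both the NVIDIA and AMD backends) must arrive in exactly the order the launcher's "iiiiii" PyArg_ParseTuple format above expects. A hedged Python sketch of that correspondence (illustrative values, not Triton source):

from collections import namedtuple

KernelMetadata = namedtuple("KernelMetadata", "num_warps num_ctas shared cluster_dims")

def pack_metadata(metadata):
    # Same field order as the backend hunks above; the C launcher unpacks
    # them positionally as num_warps, num_ctas, shared_memory, then the
    # three cluster dimensions.
    return (
        metadata.num_warps,
        metadata.num_ctas,
        metadata.shared,
        metadata.cluster_dims[0],
        metadata.cluster_dims[1],
        metadata.cluster_dims[2],
    )

meta = KernelMetadata(num_warps=8, num_ctas=2, shared=49152, cluster_dims=(2, 1, 1))
num_warps, num_ctas, shared_memory, cdx, cdy, cdz = pack_metadata(meta)
assert (cdx, cdy, cdz) == (2, 1, 1)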

third_party/proton/tutorials/matmul.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def metadata_fn(
     grid_x, grid_y, grid_z = unpack_grid(grid)
     num_warps = metadata.num_warps
     num_stages = metadata.num_stages
-    cluster_x, cluster_y, cluster_z = unpack_grid((metadata.num_ctas, ))
+    cluster_x, cluster_y, cluster_z = metadata.cluster_dims
     shared_memory = metadata.shared
     M, K = args["a_ptr"].shape
     K, N = args["b_ptr"].shape
