
Commit efd7b71

Reland upstream commit 9b75018 (#5594)
#5527: adds the matching change to Intel's backend.
2 parents: f80d8c2 + 16e01f8

13 files changed: +49 −93 lines


lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 14 additions & 0 deletions
@@ -179,11 +179,25 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
             "ttg.total-num-warps"))
      numWarps = totalNumWarps.getInt();
 
+    int numCTAs = 1;
+    if (auto module = funcOp->getParentOfType<ModuleOp>()) {
+      if (auto moduleAttr =
+              module->getAttrOfType<IntegerAttr>(triton::gpu::AttrNumCTAsName))
+        numCTAs = moduleAttr.getInt();
+    }
+
     // Set `nvvm.maxnreg` if it was specified on the module.
     if (Attribute maxnregAttr =
             funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName))
       newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);
 
+    // Do we want to do this for nCTAs == 1 whenever sm >= 90?
+    if (numCTAs > 1) {
+      // Request a specific number of CTAs per cluster in the generated PTX.
+      newFuncOp->setAttr(NVVM::NVVMDialect::getClusterDimAttrName(),
+                         rewriter.getDenseI32ArrayAttr(numCTAs));
+    }
+
     // Set an attribute for reqntidx, it could be used in latter LLVM codegen
     // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
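Note (not part of the commit): with this change the cluster shape is read from the `ttg.num-ctas` module attribute during the TritonGPU-to-LLVM conversion and emitted as `nvvm.cluster_dim` on the function, instead of being threaded through Python-side metadata. A minimal usage sketch of how this path gets exercised from user code, assuming a Hopper-class GPU where `num_ctas > 1` is supported; the kernel and tensor names are illustrative only:

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


src = torch.randn(1024, device="cuda")
dst = torch.empty_like(src)
# num_ctas=4 requests a 4-CTA cluster; after this commit the resulting
# nvvm.cluster_dim attribute is derived from the ttg.num-ctas module attribute
# (see the lit test below), not from a cluster_dims metadata entry.
copy_kernel[(src.numel() // 1024,)](src, dst, src.numel(), BLOCK=1024, num_ctas=4)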

python/triton/compiler/compiler.py

Lines changed: 0 additions & 13 deletions
@@ -297,18 +297,6 @@ def compile(src, target=None, options=None, _env_vars=None):
 
     metadata["cache_dir"] = fn_cache_manager.cache_dir
     metadata["triton_version"] = __version__
-    cluster_dims = getattr(options, "cluster_dims", None)
-    if cluster_dims is None:
-        num_ctas = getattr(options, "num_ctas", None)
-        if num_ctas is None:
-            num_ctas = 1
-        cluster_dims = (num_ctas, 1, 1)
-    if not isinstance(cluster_dims, (list, tuple)):
-        cluster_dims = (cluster_dims, )
-    cluster_dims = tuple(cluster_dims)
-    if len(cluster_dims) < 3:
-        cluster_dims = cluster_dims + (1, ) * (3 - len(cluster_dims))
-    metadata["cluster_dims"] = cluster_dims
     # run compilation pipeline and populate metadata
     stages = dict()
     backend.add_stages(stages, options, src.language)
@@ -435,7 +423,6 @@ def __init__(self, src, metadata_group, hash):
         from collections import namedtuple
         metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
         metadata = json.loads(metadata_path.read_text())
-        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
         # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
         target = metadata['target']
         metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
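Since `metadata["cluster_dims"]` is no longer written here (nor re-tupled when loading a cached kernel), external tooling that read it needs an alternative. A small illustrative helper — an assumption, not something this commit adds — that recovers the same `(X, 1, 1)` shape the removed code produced, from the `num_ctas` value that remains in the metadata:

def cluster_dims_from_metadata(metadata: dict) -> tuple[int, int, int]:
    """Reproduce the (num_ctas, 1, 1) cluster shape the removed code stored."""
    num_ctas = metadata.get("num_ctas", 1)
    return (num_ctas, 1, 1)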

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 12 additions & 0 deletions
@@ -1,5 +1,17 @@
 // RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv='compute-capability=90 ptx-version=81' --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' | FileCheck %s
 
+module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: @test_cluster_attr
+  // CHECK: nvvm.cluster_dim = array<i32: 4>
+  // CHECK: nvvm.kernel = 1 : ui1
+  // CHECK: nvvm.reqntid = array<i32: 128>
+  tt.func @test_cluster_attr(%lb : index, %A : !tt.ptr<f16>) {
+    tt.return
+  }
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>

third_party/amd/backend/compiler.py

Lines changed: 0 additions & 4 deletions
@@ -34,7 +34,6 @@ class HIPOptions:
     num_stages: int = 2
     num_ctas: int = 1
     extern_libs: dict = None
-    cluster_dims: tuple = (1, 1, 1)
     debug: bool = False
     sanitize_overflow: bool = True
     arch: str = None
@@ -146,9 +145,6 @@ def pack_metadata(self, metadata):
             metadata.num_warps,
             metadata.num_ctas,
             metadata.shared,
-            metadata.cluster_dims[0],
-            metadata.cluster_dims[1],
-            metadata.cluster_dims[2],
         )
 
     def get_codegen_implementation(self, options):
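The three remaining fields line up with the launcher change in the next file: the generated C launcher now unpacks exactly this 3-tuple. A condensed, self-contained sketch of the pairing, with names taken from the diffs but the metadata type simplified for illustration:

from collections import namedtuple

# Simplified stand-in for the compiled-kernel metadata object (illustrative only).
KernelMetadata = namedtuple("KernelMetadata", ["num_warps", "num_ctas", "shared"])

def pack_metadata(metadata: KernelMetadata) -> tuple[int, int, int]:
    # Mirrors the post-change AMD (and NVIDIA) pack_metadata: cluster dims are
    # no longer forwarded to the launcher.
    return (metadata.num_warps, metadata.num_ctas, metadata.shared)

# The generated C launcher unpacks exactly this 3-tuple:
#   int num_warps, num_ctas, shared_memory;
#   if (!PyArg_ParseTuple(kernel_metadata, "iii",
#                         &num_warps, &num_ctas, &shared_memory)) { ... }
print(pack_metadata(KernelMetadata(num_warps=4, num_ctas=1, shared=0)))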

third_party/amd/backend/driver.py

Lines changed: 4 additions & 4 deletions
@@ -487,7 +487,7 @@ def format_of(ty):
 
 #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   if (gridX * gridY * gridZ == 0)
     return;
   hipDeviceptr_t global_scratch = 0;
@@ -632,8 +632,8 @@ def format_of(ty):
 }}
 
   // extract kernel metadata
-  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
-  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
+  int num_warps, num_ctas, shared_memory;
+  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
     return NULL;
   }}
   // extract launch metadata
@@ -657,7 +657,7 @@ def format_of(ty):
   {newline.join(tensor_desc_decls)}
   {newline.join(ptr_decls)}
   {newline.join(float_storage_decls)}
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
 
   if(launch_exit_hook != Py_None){{
     PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);

third_party/intel/backend/compiler.py

Lines changed: 0 additions & 7 deletions
@@ -246,12 +246,6 @@ def make_ttir(cls, mod, metadata, opt):
     @classmethod
     @track
     def make_ttgir(cls, mod, metadata, opt, properties):
-        cluster_info = intel.ClusterInfo()
-        if opt.cluster_dims is not None:
-            cluster_info.clusterDimX = opt.cluster_dims[0]
-            cluster_info.clusterDimY = opt.cluster_dims[1]
-            cluster_info.clusterDimZ = opt.cluster_dims[2]
-
         # Annotate module with information required by subsequent transformations.
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
@@ -303,7 +297,6 @@ def make_ttgir(cls, mod, metadata, opt, properties):
         intel.passes.ttgpuir.add_optimize_reduction_locality(pm)
         intel.passes.arith.add_arith_emulate_unsupported_floats(pm, ["bf16"], "f32")
         pm.run(mod, 'make_ttgir')
-        metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
         return mod
 
     def gluon_to_ttgir(self, src, metadata, options):

third_party/intel/backend/driver.py

Lines changed: 0 additions & 10 deletions
@@ -738,16 +738,6 @@ def format_of(ty):
   int threads_per_warp = PyLong_AsLong(threads_per_warp_attr);
   Py_DECREF(threads_per_warp_attr);
 
-  // extract cluster dims
-  PyObject *clusterDim = PyObject_GetAttrString(kernel_metadata, "cluster_dims");
-  if (!PyTuple_Check(kernel_metadata)) {{
-    PyErr_SetString(PyExc_TypeError, "kernel_metadata.cluster_dims must be a tuple");
-    return NULL;
-  }}
-  int clusterDimX = PyLong_AsLong(PyTuple_GetItem(clusterDim, 0));
-  int clusterDimY = PyLong_AsLong(PyTuple_GetItem(clusterDim, 1));
-  int clusterDimZ = PyLong_AsLong(PyTuple_GetItem(clusterDim, 2));
-  Py_DECREF(clusterDim);
   // extract launch metadata
   if (launch_enter_hook != Py_None){{
     PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h

Lines changed: 0 additions & 8 deletions
@@ -13,14 +13,6 @@
 
 namespace mlir::triton::gpu::intel {
 
-// Used by Triton runtime
-struct ClusterInfo {
-  ClusterInfo() = default;
-  unsigned clusterDimX = 1u;
-  unsigned clusterDimY = 1u;
-  unsigned clusterDimZ = 1u;
-};
-
 /// Split barrier scope
 enum SplitBarrierScope { None = 0, Workgroup = 1, Subgroup = 2 };
 

third_party/intel/triton_xpu.cc

Lines changed: 0 additions & 13 deletions
@@ -136,19 +136,6 @@ void init_triton_intel(py::module &&m) {
   init_triton_intel_passes_ttgpuir(passes.def_submodule("ttgpuir"));
   init_triton_intel_passes_arith(passes.def_submodule("arith"));
 
-  // cluster info
-  py::class_<gpu::intel::ClusterInfo>(m, "ClusterInfo")
-      .def(py::init<>())
-      .def_readwrite("clusterDimX", &gpu::intel::ClusterInfo::clusterDimX)
-      .def_readwrite("clusterDimY", &gpu::intel::ClusterInfo::clusterDimY)
-      .def_readwrite("clusterDimZ", &gpu::intel::ClusterInfo::clusterDimZ)
-      .def("__repr__", [](gpu::intel::ClusterInfo &self) {
-        std::ostringstream oss;
-        oss << "(" << self.clusterDimX << ", " << self.clusterDimY << ", "
-            << self.clusterDimZ << ")";
-        return oss.str();
-      });
-
   // Split barrier scope enum
   py::enum_<gpu::intel::SplitBarrierScope>(m, "SplitBarrierScope")
       .value("none", gpu::intel::SplitBarrierScope::None)

third_party/nvidia/backend/compiler.py

Lines changed: 0 additions & 7 deletions
@@ -205,9 +205,6 @@ def pack_metadata(self, metadata):
             metadata.num_warps,
             metadata.num_ctas,
             metadata.shared,
-            metadata.cluster_dims[0],
-            metadata.cluster_dims[1],
-            metadata.cluster_dims[2],
         )
 
     def get_codegen_implementation(self, options):
@@ -317,8 +314,6 @@ def make_ttgir(mod, metadata, opt, capability):
         passes.common.add_canonicalizer(pm)
 
         pm.run(mod, 'make_ttgir')
-        # num_ctas == 16 is non-portable. Does work for H100 and B200 tho
-        metadata["cluster_dims"] = (opt.num_ctas, 1, 1)
        metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
@@ -338,8 +333,6 @@ def gluon_to_ttgir(self, src, metadata, options, capability):
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
 
         pm.run(mod, 'gluon_to_ttgir')
-        # num_ctas == 16 is non-portable. Does work for H100 and B200 tho
-        metadata["cluster_dims"] = (options.num_ctas, 1, 1)
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
