diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index 6c4bd9c1a7..efcd76fe7e 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -179,11 +179,25 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> { "ttg.total-num-warps")) numWarps = totalNumWarps.getInt(); + int numCTAs = 1; + if (auto module = funcOp->getParentOfType<ModuleOp>()) { + if (auto moduleAttr = + module->getAttrOfType<IntegerAttr>(triton::gpu::AttrNumCTAsName)) + numCTAs = moduleAttr.getInt(); + } + // Set `nvvm.maxnreg` if it was specified on the module. if (Attribute maxnregAttr = funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName)) newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr); + // Do we want to do this for nCTAs == 1 whenever sm >= 90? + if (numCTAs > 1) { + // Request a specific number of CTAs per cluster in the generated PTX. + newFuncOp->setAttr(NVVM::NVVMDialect::getClusterDimAttrName(), + rewriter.getDenseI32ArrayAttr(numCTAs)); + } + // Set an attribute for reqntidx, it could be used in latter LLVM codegen // for `nvvm.annotation` metadata. 
newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(), diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 1fd15908b0..28bd42cca9 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -297,18 +297,6 @@ def compile(src, target=None, options=None, _env_vars=None): metadata["cache_dir"] = fn_cache_manager.cache_dir metadata["triton_version"] = __version__ - cluster_dims = getattr(options, "cluster_dims", None) - if cluster_dims is None: - num_ctas = getattr(options, "num_ctas", None) - if num_ctas is None: - num_ctas = 1 - cluster_dims = (num_ctas, 1, 1) - if not isinstance(cluster_dims, (list, tuple)): - cluster_dims = (cluster_dims, ) - cluster_dims = tuple(cluster_dims) - if len(cluster_dims) < 3: - cluster_dims = cluster_dims + (1, ) * (3 - len(cluster_dims)) - metadata["cluster_dims"] = cluster_dims # run compilation pipeline and populate metadata stages = dict() backend.add_stages(stages, options, src.language) @@ -435,7 +423,6 @@ def __init__(self, src, metadata_group, hash): from collections import namedtuple metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) metadata = json.loads(metadata_path.read_text()) - metadata['cluster_dims'] = tuple(metadata['cluster_dims']) # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. 
target = metadata['target'] metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 8f4faaef0b..017e08b39d 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -1,5 +1,17 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv='compute-capability=90 ptx-version=81' --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' | FileCheck %s +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: @test_cluster_attr + // CHECK: nvvm.cluster_dim = array<i32: 4> + // CHECK: nvvm.kernel = 1 : ui1 + // CHECK: nvvm.reqntid = array<i32: 128> + tt.func @test_cluster_attr(%lb : index, %A : !tt.ptr<f16>) { + tt.return + } +} + +// ----- + #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}> #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}> diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 98b7ec2f40..887802333d 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -34,7 +34,6 @@ class HIPOptions: num_stages: int = 2 num_ctas: int = 1 extern_libs: dict = None - cluster_dims: tuple = (1, 1, 1) debug: bool = False sanitize_overflow: bool = True arch: str = None @@ -146,9 +145,6 @@ def pack_metadata(self, metadata): metadata.num_warps, metadata.num_ctas, metadata.shared, - metadata.cluster_dims[0], - metadata.cluster_dims[1], - metadata.cluster_dims[2], ) def get_codegen_implementation(self, options): diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 9df8196638..24a0d84e8a 100644 --- a/third_party/amd/backend/driver.py +++ 
b/third_party/amd/backend/driver.py @@ -487,7 +487,7 @@ def format_of(ty): #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ if (gridX * gridY * gridZ == 0) return; hipDeviceptr_t global_scratch = 0; @@ -632,8 +632,8 @@ def format_of(ty): }} // extract kernel metadata - int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; - if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + int num_warps, num_ctas, shared_memory; + if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{ return NULL; }} // extract launch metadata @@ -657,7 +657,7 @@ def format_of(ty): {newline.join(tensor_desc_decls)} {newline.join(ptr_decls)} {newline.join(float_storage_decls)} - _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); + _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); if(launch_exit_hook != Py_None){{ PyObject* ret = 
PyObject_CallOneArg(launch_exit_hook, launch_metadata); diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 7561e50eae..9b6c88a33f 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -246,12 +246,6 @@ def make_ttir(cls, mod, metadata, opt): @classmethod @track def make_ttgir(cls, mod, metadata, opt, properties): - cluster_info = intel.ClusterInfo() - if opt.cluster_dims is not None: - cluster_info.clusterDimX = opt.cluster_dims[0] - cluster_info.clusterDimY = opt.cluster_dims[1] - cluster_info.clusterDimZ = opt.cluster_dims[2] - # Annotate module with information required by subsequent transformations. pm = ir.pass_manager(mod.context) pm.enable_debug() @@ -303,7 +297,6 @@ def make_ttgir(cls, mod, metadata, opt, properties): intel.passes.ttgpuir.add_optimize_reduction_locality(pm) intel.passes.arith.add_arith_emulate_unsupported_floats(pm, ["bf16"], "f32") pm.run(mod, 'make_ttgir') - metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) return mod def gluon_to_ttgir(self, src, metadata, options): diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index be59664675..0faee49b73 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -738,16 +738,6 @@ def format_of(ty): int threads_per_warp = PyLong_AsLong(threads_per_warp_attr); Py_DECREF(threads_per_warp_attr); - // extract cluster dims - PyObject *clusterDim = PyObject_GetAttrString(kernel_metadata, "cluster_dims"); - if (!PyTuple_Check(kernel_metadata)) {{ - PyErr_SetString(PyExc_TypeError, "kernel_metadata.cluster_dims must be a tuple"); - return NULL; - }} - int clusterDimX = PyLong_AsLong(PyTuple_GetItem(clusterDim, 0)); - int clusterDimY = PyLong_AsLong(PyTuple_GetItem(clusterDim, 1)); - int clusterDimZ = PyLong_AsLong(PyTuple_GetItem(clusterDim, 2)); - Py_DECREF(clusterDim); // extract launch 
metadata if (launch_enter_hook != Py_None){{ PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata); diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h index 380ba626b1..0305242572 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h @@ -13,14 +13,6 @@ namespace mlir::triton::gpu::intel { -// Used by Triton runtime -struct ClusterInfo { - ClusterInfo() = default; - unsigned clusterDimX = 1u; - unsigned clusterDimY = 1u; - unsigned clusterDimZ = 1u; -}; - /// Split barrier scope enum SplitBarrierScope { None = 0, Workgroup = 1, Subgroup = 2 }; diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index f7b73475eb..2bdf1c6586 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -136,19 +136,6 @@ void init_triton_intel(py::module &&m) { init_triton_intel_passes_ttgpuir(passes.def_submodule("ttgpuir")); init_triton_intel_passes_arith(passes.def_submodule("arith")); - // cluster info - py::class_<gpu::intel::ClusterInfo>(m, "ClusterInfo") - .def(py::init<>()) - .def_readwrite("clusterDimX", &gpu::intel::ClusterInfo::clusterDimX) - .def_readwrite("clusterDimY", &gpu::intel::ClusterInfo::clusterDimY) - .def_readwrite("clusterDimZ", &gpu::intel::ClusterInfo::clusterDimZ) - .def("__repr__", [](gpu::intel::ClusterInfo &self) { - std::ostringstream oss; - oss << "(" << self.clusterDimX << ", " << self.clusterDimY << ", " - << self.clusterDimZ << ")"; - return oss.str(); - }); - // Split barrier scope enum py::enum_<gpu::intel::SplitBarrierScope>(m, "SplitBarrierScope") .value("none", gpu::intel::SplitBarrierScope::None) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 04bdca9650..079d1195d0 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -205,9 +205,6 @@ 
def pack_metadata(self, metadata): metadata.num_warps, metadata.num_ctas, metadata.shared, - metadata.cluster_dims[0], - metadata.cluster_dims[1], - metadata.cluster_dims[2], ) def get_codegen_implementation(self, options): @@ -317,8 +314,6 @@ def make_ttgir(mod, metadata, opt, capability): passes.common.add_canonicalizer(pm) pm.run(mod, 'make_ttgir') - # num_ctas == 16 is non-portable. Does work for H100 and B200 tho - metadata["cluster_dims"] = (opt.num_ctas, 1, 1) metadata["tensordesc_meta"] = mod.get_tensordesc_metadata() return mod @@ -338,8 +333,6 @@ def gluon_to_ttgir(self, src, metadata, options, capability): passes.ttgpuir.add_combine_tensor_select_and_if(pm) pm.run(mod, 'gluon_to_ttgir') - # num_ctas == 16 is non-portable. Does work for H100 and B200 tho - metadata["cluster_dims"] = (options.num_ctas, 1, 1) metadata["tensordesc_meta"] = mod.get_tensordesc_metadata() return mod diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index f7d23162f0..9603fec8e1 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -232,13 +232,11 @@ defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, cuTensorMapEncodeTiled); static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { - int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, - maxActiveClusters = -1; + int clusterDim = -1, maxActiveClusters = -1; int shared = 0; CUfunction func; - if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, - &clusterDimY, &clusterDimZ)) { + if (!PyArg_ParseTuple(args, "Kii", &func, &shared, &clusterDim)) { return NULL; } @@ -251,13 +249,13 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { CUlaunchAttribute launchAttr[1]; launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launchAttr[0].value.clusterDim.x = clusterDimX; - launchAttr[0].value.clusterDim.y = clusterDimY; - launchAttr[0].value.clusterDim.z = clusterDimZ; + 
launchAttr[0].value.clusterDim.x = clusterDim; + launchAttr[0].value.clusterDim.y = 1; + launchAttr[0].value.clusterDim.z = 1; CUlaunchConfig config; - config.gridDimX = clusterDimX; - config.gridDimY = maxActiveBlocks * clusterDimY; - config.gridDimZ = clusterDimZ; + config.gridDimX = clusterDim * maxActiveBlocks; + config.gridDimY = 1; + config.gridDimZ = 1; config.blockDimX = 128; config.blockDimY = 1; config.blockDimZ = 1; diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index d4c6c4325b..02a3219dff 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -1,5 +1,4 @@ import functools -import operator import os import subprocess import triton @@ -339,7 +338,7 @@ def format_of(ty): }} #endif -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ void *params[] = {{ {', '.join(params)} }}; if (gridX*gridY*gridZ > 0) {{ // 4 attributes that we can currently pass maximum @@ -349,16 +348,10 @@ def format_of(ty): cuLaunchKernelExHandle = getLaunchKernelExHandle(); }} CUlaunchConfig config; - config.gridDimX = gridX; + config.gridDimX = gridX * num_ctas; config.gridDimY = gridY; config.gridDimZ = gridZ; - if (num_ctas != 1) {{ - config.gridDimX *= clusterDimX; - config.gridDimY *= clusterDimY; - config.gridDimZ *= clusterDimZ; - }} - config.blockDimX = 32 * num_warps; config.blockDimY = 1; config.blockDimZ = 1; @@ 
-382,9 +375,9 @@ def format_of(ty): if (num_ctas != 1) {{ CUlaunchAttribute clusterAttr = {{}}; clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - clusterAttr.value.clusterDim.x = clusterDimX; - clusterAttr.value.clusterDim.y = clusterDimY; - clusterAttr.value.clusterDim.z = clusterDimZ; + clusterAttr.value.clusterDim.x = num_ctas; + clusterAttr.value.clusterDim.y = 1; + clusterAttr.value.clusterDim.z = 1; launchAttr[num_attrs] = clusterAttr; ++num_attrs; @@ -395,6 +388,7 @@ def format_of(ty): ++num_attrs; }} + // num_ctas == 16 is non-portable. Does work for H100 and B200 tho config.numAttrs = num_attrs; if (num_ctas == 16) {{ CUDA_CHECK(cuFuncSetAttribute( @@ -540,8 +534,8 @@ def format_of(ty): return NULL; }} - int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; - if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + int num_warps, num_ctas, shared_memory; + if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{ PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); return NULL; }} @@ -577,7 +571,7 @@ def format_of(ty): {newline.join(tma_decls)} {newline.join(float_storage_decls)} Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); + _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); Py_END_ALLOW_THREADS; if (PyErr_Occurred()) {{ return NULL; @@ -719,7 +713,7 @@ def __init__(self, src, metadata): libraries=libraries, ) - self.num_ctas 
= functools.reduce(operator.mul, metadata.cluster_dims, 1) + self.num_ctas = getattr(metadata, "num_ctas", 1) self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta) self.global_scratch_size = metadata.global_scratch_size self.global_scratch_align = metadata.global_scratch_align diff --git a/third_party/proton/tutorials/matmul.py b/third_party/proton/tutorials/matmul.py index 992177d775..e5e09d1dda 100644 --- a/third_party/proton/tutorials/matmul.py +++ b/third_party/proton/tutorials/matmul.py @@ -24,7 +24,7 @@ def metadata_fn( grid_x, grid_y, grid_z = unpack_grid(grid) num_warps = metadata.num_warps num_stages = metadata.num_stages - cluster_x, cluster_y, cluster_z = metadata.cluster_dims + cluster_x, cluster_y, cluster_z = unpack_grid((metadata.num_ctas, )) shared_memory = metadata.shared M, K = args["a_ptr"].shape K, N = args["b_ptr"].shape