14 changes: 14 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
@@ -179,11 +179,25 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
"ttg.total-num-warps"))
numWarps = totalNumWarps.getInt();

int numCTAs = 1;
if (auto module = funcOp->getParentOfType<ModuleOp>()) {
if (auto moduleAttr =
module->getAttrOfType<IntegerAttr>(triton::gpu::AttrNumCTAsName))
numCTAs = moduleAttr.getInt();
}

// Set `nvvm.maxnreg` if it was specified on the module.
if (Attribute maxnregAttr =
funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName))
newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);

// Do we want to do this for nCTAs == 1 whenever sm >= 90?
if (numCTAs > 1) {
// Request a specific number of CTAs per cluster in the generated PTX.
newFuncOp->setAttr(NVVM::NVVMDialect::getClusterDimAttrName(),
rewriter.getDenseI32ArrayAttr(numCTAs));
}

// Set an attribute for reqntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
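A minimal sketch of the rule this hunk adds, written as plain Python for illustration only; the real logic is the C++ above, and `cluster_dim_attr` is a hypothetical name.

```python
# Illustration only, not part of the PR: FuncOpConversion's decision,
# restated in Python. `cluster_dim_attr` is a hypothetical helper name.
from typing import List, Optional


def cluster_dim_attr(num_ctas: int) -> Optional[List[int]]:
    """Value for `nvvm.cluster_dim`, or None when the attribute is omitted."""
    # ttg.num-ctas defaults to 1; the attribute is only attached when the
    # module requests more than one CTA per cluster.
    if num_ctas > 1:
        return [num_ctas]  # lowered as array<i32: num_ctas> on the LLVM func
    return None


assert cluster_dim_attr(1) is None
assert cluster_dim_attr(4) == [4]  # matches the MLIR test added below
```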
13 changes: 0 additions & 13 deletions python/triton/compiler/compiler.py
@@ -297,18 +297,6 @@ def compile(src, target=None, options=None, _env_vars=None):

metadata["cache_dir"] = fn_cache_manager.cache_dir
metadata["triton_version"] = __version__
cluster_dims = getattr(options, "cluster_dims", None)
if cluster_dims is None:
num_ctas = getattr(options, "num_ctas", None)
if num_ctas is None:
num_ctas = 1
cluster_dims = (num_ctas, 1, 1)
if not isinstance(cluster_dims, (list, tuple)):
cluster_dims = (cluster_dims, )
cluster_dims = tuple(cluster_dims)
if len(cluster_dims) < 3:
cluster_dims = cluster_dims + (1, ) * (3 - len(cluster_dims))
metadata["cluster_dims"] = cluster_dims
# run compilation pipeline and populate metadata
stages = dict()
backend.add_stages(stages, options, src.language)
@@ -435,7 +423,6 @@ def __init__(self, src, metadata_group, hash):
from collections import namedtuple
metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
metadata = json.loads(metadata_path.read_text())
metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
# JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
target = metadata['target']
metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
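The code removed here is what used to normalize `num_ctas` into a padded 3-tuple. A hedged sketch of how a consumer of the cached metadata can rebuild the same shape now that only `num_ctas` survives; the helper name is illustrative.

```python
# Hypothetical helper, not part of the PR: rebuild the 3-tuple that the
# removed code used to store as metadata["cluster_dims"], from the num_ctas
# value that remains in the metadata (pack_metadata below still reads it).
def cluster_dims_from_metadata(metadata: dict) -> tuple:
    num_ctas = metadata.get("num_ctas", 1)
    return (num_ctas, 1, 1)


assert cluster_dims_from_metadata({"num_ctas": 4}) == (4, 1, 1)
assert cluster_dims_from_metadata({}) == (1, 1, 1)
```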
12 changes: 12 additions & 0 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -1,5 +1,17 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv='compute-capability=90 ptx-version=81' --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' | FileCheck %s

module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: @test_cluster_attr
// CHECK: nvvm.cluster_dim = array<i32: 4>
// CHECK: nvvm.kernel = 1 : ui1
// CHECK: nvvm.reqntid = array<i32: 128>
tt.func @test_cluster_attr(%lb : index, %A : !tt.ptr<f16>) {
tt.return
}
}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
4 changes: 0 additions & 4 deletions third_party/amd/backend/compiler.py
@@ -34,7 +34,6 @@ class HIPOptions:
num_stages: int = 2
num_ctas: int = 1
extern_libs: dict = None
cluster_dims: tuple = (1, 1, 1)
debug: bool = False
sanitize_overflow: bool = True
arch: str = None
@@ -146,9 +145,6 @@ def pack_metadata(self, metadata):
metadata.num_warps,
metadata.num_ctas,
metadata.shared,
metadata.cluster_dims[0],
metadata.cluster_dims[1],
metadata.cluster_dims[2],
)

def get_codegen_implementation(self, options):
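With the three cluster entries dropped, the tuple built by `pack_metadata` has to stay in sync with the format string used by the generated HIP launcher in the next file, which moves from "iiiiii" to "iii". A sketch with illustrative values:

```python
# Illustrative values only; the field order mirrors pack_metadata above.
from types import SimpleNamespace

metadata = SimpleNamespace(num_warps=8, num_ctas=1, shared=16384)
kernel_metadata = (metadata.num_warps, metadata.num_ctas, metadata.shared)

# C side of the same contract, for reference:
#   int num_warps, num_ctas, shared_memory;
#   PyArg_ParseTuple(kernel_metadata, "iii",
#                    &num_warps, &num_ctas, &shared_memory);
assert kernel_metadata == (8, 1, 16384)
```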
8 changes: 4 additions & 4 deletions third_party/amd/backend/driver.py
@@ -487,7 +487,7 @@ def format_of(ty):

#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
if (gridX * gridY * gridZ == 0)
return;
hipDeviceptr_t global_scratch = 0;
@@ -632,8 +632,8 @@ def format_of(ty):
}}

// extract kernel metadata
int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
int num_warps, num_ctas, shared_memory;
if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
return NULL;
}}
// extract launch metadata
@@ -657,7 +657,7 @@ def format_of(ty):
{newline.join(tensor_desc_decls)}
{newline.join(ptr_decls)}
{newline.join(float_storage_decls)}
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});

if(launch_exit_hook != Py_None){{
PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
7 changes: 0 additions & 7 deletions third_party/intel/backend/compiler.py
@@ -246,12 +246,6 @@ def make_ttir(cls, mod, metadata, opt):
@classmethod
@track
def make_ttgir(cls, mod, metadata, opt, properties):
cluster_info = intel.ClusterInfo()
if opt.cluster_dims is not None:
cluster_info.clusterDimX = opt.cluster_dims[0]
cluster_info.clusterDimY = opt.cluster_dims[1]
cluster_info.clusterDimZ = opt.cluster_dims[2]

# Annotate module with information required by subsequent transformations.
pm = ir.pass_manager(mod.context)
pm.enable_debug()
@@ -303,7 +297,6 @@ def make_ttgir(cls, mod, metadata, opt, properties):
intel.passes.ttgpuir.add_optimize_reduction_locality(pm)
intel.passes.arith.add_arith_emulate_unsupported_floats(pm, ["bf16"], "f32")
pm.run(mod, 'make_ttgir')
metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
return mod

def gluon_to_ttgir(self, src, metadata, options):
10 changes: 0 additions & 10 deletions third_party/intel/backend/driver.py
@@ -738,16 +738,6 @@ def format_of(ty):
int threads_per_warp = PyLong_AsLong(threads_per_warp_attr);
Py_DECREF(threads_per_warp_attr);
// extract cluster dims
PyObject *clusterDim = PyObject_GetAttrString(kernel_metadata, "cluster_dims");
if (!PyTuple_Check(kernel_metadata)) {{
PyErr_SetString(PyExc_TypeError, "kernel_metadata.cluster_dims must be a tuple");
return NULL;
}}
int clusterDimX = PyLong_AsLong(PyTuple_GetItem(clusterDim, 0));
int clusterDimY = PyLong_AsLong(PyTuple_GetItem(clusterDim, 1));
int clusterDimZ = PyLong_AsLong(PyTuple_GetItem(clusterDim, 2));
Py_DECREF(clusterDim);
// extract launch metadata
if (launch_enter_hook != Py_None){{
PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
@@ -13,14 +13,6 @@

namespace mlir::triton::gpu::intel {

// Used by Triton runtime
struct ClusterInfo {
ClusterInfo() = default;
unsigned clusterDimX = 1u;
unsigned clusterDimY = 1u;
unsigned clusterDimZ = 1u;
};

/// Split barrier scope
enum SplitBarrierScope { None = 0, Workgroup = 1, Subgroup = 2 };

13 changes: 0 additions & 13 deletions third_party/intel/triton_xpu.cc
@@ -136,19 +136,6 @@ void init_triton_intel(py::module &&m) {
init_triton_intel_passes_ttgpuir(passes.def_submodule("ttgpuir"));
init_triton_intel_passes_arith(passes.def_submodule("arith"));

// cluster info
py::class_<gpu::intel::ClusterInfo>(m, "ClusterInfo")
.def(py::init<>())
.def_readwrite("clusterDimX", &gpu::intel::ClusterInfo::clusterDimX)
.def_readwrite("clusterDimY", &gpu::intel::ClusterInfo::clusterDimY)
.def_readwrite("clusterDimZ", &gpu::intel::ClusterInfo::clusterDimZ)
.def("__repr__", [](gpu::intel::ClusterInfo &self) {
std::ostringstream oss;
oss << "(" << self.clusterDimX << ", " << self.clusterDimY << ", "
<< self.clusterDimZ << ")";
return oss.str();
});

// Split barrier scope enum
py::enum_<gpu::intel::SplitBarrierScope>(m, "SplitBarrierScope")
.value("none", gpu::intel::SplitBarrierScope::None)
7 changes: 0 additions & 7 deletions third_party/nvidia/backend/compiler.py
@@ -205,9 +205,6 @@ def pack_metadata(self, metadata):
metadata.num_warps,
metadata.num_ctas,
metadata.shared,
metadata.cluster_dims[0],
metadata.cluster_dims[1],
metadata.cluster_dims[2],
)

def get_codegen_implementation(self, options):
@@ -317,8 +314,6 @@ def make_ttgir(mod, metadata, opt, capability):
passes.common.add_canonicalizer(pm)

pm.run(mod, 'make_ttgir')
# num_ctas == 16 is non-portable. Does work for H100 and B200 tho
metadata["cluster_dims"] = (opt.num_ctas, 1, 1)
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
return mod

@@ -338,8 +333,6 @@ def gluon_to_ttgir(self, src, metadata, options, capability):
passes.ttgpuir.add_combine_tensor_select_and_if(pm)

pm.run(mod, 'gluon_to_ttgir')
# num_ctas == 16 is non-portable. Does work for H100 and B200 tho
metadata["cluster_dims"] = (options.num_ctas, 1, 1)
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
return mod

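After these hunks the NVIDIA path no longer records a cluster shape in metadata at all: the compile-time side is the `nvvm.cluster_dim` attribute from FuncOpToLLVM.cpp, and the launch-time side is rebuilt from `num_ctas` in the driver changes that follow. A hedged summary sketch with made-up values:

```python
# Illustrative sketch of how cluster information flows once
# metadata["cluster_dims"] is gone; values are made up.
num_ctas = 4  # ttg.num-ctas on the module

# Compile time: FuncOpConversion (first file in this PR) attaches
# nvvm.cluster_dim = array<i32: 4> to the kernel.
# Launch time (generated launcher below): gridDimX is always scaled by
# num_ctas, and the cluster launch attribute is only set when num_ctas != 1.
def effective_grid_x(grid_x: int) -> int:
    return grid_x * num_ctas


cluster_attr = (num_ctas, 1, 1) if num_ctas != 1 else None
assert effective_grid_x(32) == 128
assert cluster_attr == (4, 1, 1)
```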
18 changes: 8 additions & 10 deletions third_party/nvidia/backend/driver.c
@@ -232,13 +232,11 @@ defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
cuTensorMapEncodeTiled);

static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
maxActiveClusters = -1;
int clusterDim = -1, maxActiveClusters = -1;
int shared = 0;
CUfunction func;

if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX,
&clusterDimY, &clusterDimZ)) {
if (!PyArg_ParseTuple(args, "Kii", &func, &shared, &clusterDim)) {
return NULL;
}

@@ -251,13 +249,13 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {

CUlaunchAttribute launchAttr[1];
launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
launchAttr[0].value.clusterDim.x = clusterDimX;
launchAttr[0].value.clusterDim.y = clusterDimY;
launchAttr[0].value.clusterDim.z = clusterDimZ;
launchAttr[0].value.clusterDim.x = clusterDim;
launchAttr[0].value.clusterDim.y = 1;
launchAttr[0].value.clusterDim.z = 1;
CUlaunchConfig config;
config.gridDimX = clusterDimX;
config.gridDimY = maxActiveBlocks * clusterDimY;
config.gridDimZ = clusterDimZ;
config.gridDimX = clusterDim * maxActiveBlocks;
config.gridDimY = 1;
config.gridDimZ = 1;
config.blockDimX = 128;
config.blockDimY = 1;
config.blockDimZ = 1;
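`occupancyMaxActiveClusters` now parses `(function, shared, clusterDim)` with "Kii" and assumes a 1-D cluster. A hedged sketch of the Python-side call; only the argument order comes from the parse call above, and the attribute name on the compiled utils module is an assumption.

```python
# Hedged sketch of the calling convention only. The (function, shared_mem,
# cluster_dim) order is taken from PyArg_ParseTuple(args, "Kii", ...) above;
# the attribute name on the compiled utils module is assumed.
def max_active_clusters(utils_mod, cu_function, shared_mem_bytes, num_ctas):
    # With this PR a cluster is always (num_ctas, 1, 1), so a single int is
    # enough to describe its shape.
    return utils_mod.occupancyMaxActiveClusters(  # assumed attribute name
        cu_function, shared_mem_bytes, num_ctas)
```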
26 changes: 10 additions & 16 deletions third_party/nvidia/backend/driver.py
@@ -1,5 +1,4 @@
import functools
import operator
import os
import subprocess
import triton
@@ -339,7 +338,7 @@ def format_of(ty):
}}
#endif

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
void *params[] = {{ {', '.join(params)} }};
if (gridX*gridY*gridZ > 0) {{
// 4 attributes that we can currently pass maximum
@@ -349,16 +348,10 @@ def format_of(ty):
cuLaunchKernelExHandle = getLaunchKernelExHandle();
}}
CUlaunchConfig config;
config.gridDimX = gridX;
config.gridDimX = gridX * num_ctas;
config.gridDimY = gridY;
config.gridDimZ = gridZ;

if (num_ctas != 1) {{
config.gridDimX *= clusterDimX;
config.gridDimY *= clusterDimY;
config.gridDimZ *= clusterDimZ;
}}

config.blockDimX = 32 * num_warps;
config.blockDimY = 1;
config.blockDimZ = 1;
@@ -382,9 +375,9 @@ def format_of(ty):
if (num_ctas != 1) {{
CUlaunchAttribute clusterAttr = {{}};
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
clusterAttr.value.clusterDim.x = clusterDimX;
clusterAttr.value.clusterDim.y = clusterDimY;
clusterAttr.value.clusterDim.z = clusterDimZ;
clusterAttr.value.clusterDim.x = num_ctas;
clusterAttr.value.clusterDim.y = 1;
clusterAttr.value.clusterDim.z = 1;
launchAttr[num_attrs] = clusterAttr;
++num_attrs;

@@ -395,6 +388,7 @@ def format_of(ty):
++num_attrs;
}}

// num_ctas == 16 is non-portable. Does work for H100 and B200 tho
config.numAttrs = num_attrs;
if (num_ctas == 16) {{
CUDA_CHECK(cuFuncSetAttribute(
@@ -540,8 +534,8 @@ def format_of(ty):
return NULL;
}}

int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
int num_warps, num_ctas, shared_memory;
if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
return NULL;
}}
@@ -577,7 +571,7 @@ def format_of(ty):
{newline.join(tma_decls)}
{newline.join(float_storage_decls)}
Py_BEGIN_ALLOW_THREADS;
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
Py_END_ALLOW_THREADS;
if (PyErr_Occurred()) {{
return NULL;
@@ -719,7 +713,7 @@ def __init__(self, src, metadata):
libraries=libraries,
)

self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
self.num_ctas = getattr(metadata, "num_ctas", 1)
self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
self.global_scratch_size = metadata.global_scratch_size
self.global_scratch_align = metadata.global_scratch_align
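The last hunk in this file replaces a product over `cluster_dims` with a direct read of `num_ctas`. A small check that the two agree for the shapes Triton emitted before this PR, where `cluster_dims` was always `(num_ctas, 1, 1)`:

```python
# Equivalence check with illustrative values; SimpleNamespace stands in for
# the real metadata object.
import functools
import operator
from types import SimpleNamespace

metadata = SimpleNamespace(num_ctas=4, cluster_dims=(4, 1, 1))
old_value = functools.reduce(operator.mul, metadata.cluster_dims, 1)
new_value = getattr(metadata, "num_ctas", 1)
assert old_value == new_value == 4
```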
2 changes: 1 addition & 1 deletion third_party/proton/tutorials/matmul.py
@@ -24,7 +24,7 @@ def metadata_fn(
grid_x, grid_y, grid_z = unpack_grid(grid)
num_warps = metadata.num_warps
num_stages = metadata.num_stages
cluster_x, cluster_y, cluster_z = metadata.cluster_dims
cluster_x, cluster_y, cluster_z = unpack_grid((metadata.num_ctas, ))
shared_memory = metadata.shared
M, K = args["a_ptr"].shape
K, N = args["b_ptr"].shape
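The tutorial's metadata hook keeps reporting three cluster components by padding `num_ctas` through `unpack_grid`. The change assumes `unpack_grid` pads a short tuple with 1s, as its use for `grid` earlier in `metadata_fn` suggests; a stand-in sketch of that behaviour:

```python
# Illustrative stand-in for proton's unpack_grid, not its actual code:
# pad a 1- or 2-element grid tuple out to three dimensions.
def unpack_grid_like(grid):
    grid = tuple(grid) + (1,) * (3 - len(grid))
    return grid[0], grid[1], grid[2]


assert unpack_grid_like((2,)) == (2, 1, 1)  # num_ctas = 2 -> clusters (2, 1, 1)
```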