Commit 98d1896

Merge commit '9b750186115b04267de6bc10d38476557bad0a53'
2 parents d59f085 + 9b75018 commit 98d1896

31 files changed: +1038 additions, -255 deletions


lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 14 additions & 0 deletions
@@ -179,11 +179,25 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
             "ttg.total-num-warps"))
       numWarps = totalNumWarps.getInt();
 
+    int numCTAs = 1;
+    if (auto module = funcOp->getParentOfType<ModuleOp>()) {
+      if (auto moduleAttr =
+              module->getAttrOfType<IntegerAttr>(triton::gpu::AttrNumCTAsName))
+        numCTAs = moduleAttr.getInt();
+    }
+
     // Set `nvvm.maxnreg` if it was specified on the module.
     if (Attribute maxnregAttr =
             funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName))
       newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);
 
+    // Do we want to do this for nCTAs == 1 whenever sm >= 90?
+    if (numCTAs > 1) {
+      // Request a specific number of CTAs per cluster in the generated PTX.
+      newFuncOp->setAttr(NVVM::NVVMDialect::getClusterDimAttrName(),
+                         rewriter.getDenseI32ArrayAttr(numCTAs));
+    }
+
     // Set an attribute for reqntidx, it could be used in latter LLVM codegen
     // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
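Note: with this change the cluster shape reaches the NVVM level through the `ttg.num-ctas` module attribute rather than the Python-side `cluster_dims` metadata removed from compiler.py below. The following Python sketch is only illustrative of where that attribute comes from: a standard Triton launch requesting CTA clusters via `num_ctas`. The kernel and sizes are made up, and it assumes an sm_90+ GPU where `num_ctas > 1` is accepted.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Plain elementwise copy; the launch below is the interesting part.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


src = torch.randn(1 << 20, device="cuda")
dst = torch.empty_like(src)
grid = (triton.cdiv(src.numel(), 1024), )
# num_ctas=2 becomes the "ttg.num-ctas" attribute on the TTGIR module;
# FuncOpConversion then attaches nvvm.cluster_dim so the cluster size is
# encoded in the generated PTX.
copy_kernel[grid](src, dst, src.numel(), BLOCK=1024, num_ctas=2)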

python/test/gluon/test_frontend.py

Lines changed: 15 additions & 0 deletions
@@ -3006,6 +3006,12 @@ def print_num_warps():
     print("num_warps", num_warps)
 
 
+@gluon.jit
+def print_num_ctas():
+    num_ctas: ttgl.constexpr = ttgl.num_ctas()
+    print("num_ctas", num_ctas)
+
+
 @filecheck_test
 @gluon.jit
 def test_get_num_warps():
@@ -3030,6 +3036,15 @@ def test_get_num_warps():
     ], [1, 2, 8], [24, 24, 24])
 
 
+@filecheck_test
+@gluon.jit
+def test_num_ctas():
+    # CHECK-LABEL: test_num_ctas
+    # CHECK: tt.func private @{{.*}}print_num_ctas
+    # CHECK-NEXT: arith.constant 1 : i32
+    print_num_ctas()
+
+
 def test_mismatch_shape_and_layout_rank():
 
     @gluon.jit

python/triton/compiler/compiler.py

Lines changed: 0 additions & 13 deletions
@@ -294,18 +294,6 @@ def compile(src, target=None, options=None, _env_vars=None):
 
     metadata["cache_dir"] = fn_cache_manager.cache_dir
     metadata["triton_version"] = __version__
-    cluster_dims = getattr(options, "cluster_dims", None)
-    if cluster_dims is None:
-        num_ctas = getattr(options, "num_ctas", None)
-        if num_ctas is None:
-            num_ctas = 1
-        cluster_dims = (num_ctas, 1, 1)
-    if not isinstance(cluster_dims, (list, tuple)):
-        cluster_dims = (cluster_dims, )
-    cluster_dims = tuple(cluster_dims)
-    if len(cluster_dims) < 3:
-        cluster_dims = cluster_dims + (1, ) * (3 - len(cluster_dims))
-    metadata["cluster_dims"] = cluster_dims
     # run compilation pipeline and populate metadata
     stages = dict()
     backend.add_stages(stages, options, src.language)
@@ -432,7 +420,6 @@ def __init__(self, src, metadata_group, hash):
         from collections import namedtuple
         metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
         metadata = json.loads(metadata_path.read_text())
-        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
         # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
         target = metadata['target']
         metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@
     full,
     gather,
     num_warps,
+    num_ctas,
     histogram,
     inline_asm_elementwise,
     join,

python/triton/experimental/gluon/language/_core.py

Lines changed: 9 additions & 0 deletions
@@ -74,6 +74,7 @@
     "static_range",
     "tuple",
    "tuple_type",
+    "num_ctas",
 ]
 
 T = TypeVar("T")
@@ -525,6 +526,14 @@ def num_warps(_semantic=None, _generator=None):
     return _semantic.num_warps(_generator)
 
 
+@builtin
+def num_ctas(_semantic=None):
+    """
+    Returns the number of CTAs in the current kernel
+    """
+    return _semantic.num_ctas()
+
+
 @builtin
 def thread_barrier(_semantic=None):
     """

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 3 additions & 0 deletions
@@ -551,6 +551,9 @@ def warp_specialize(self, functions_and_args, worker_num_warps: Sequence[int], w
             return
         return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
 
+    def num_ctas(self):
+        return ttgl.constexpr(self.builder.options.num_ctas)
+
     def num_warps(self, generator):
         if generator.caller_context is not None:
             assert isinstance(generator.caller_context, GluonCallerContext)
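Taken together with the `__init__.py` and `_core.py` changes above, the new builtin is exposed as `ttgl.num_ctas()` and folds to a `ttgl.constexpr` taken from the compile options, mirroring `ttgl.num_warps()`. A minimal illustrative sketch of using it from a Gluon kernel (the kernel name is made up; imports follow the usual Gluon test style):

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def report_launch_shape():
    # Both values are compile-time constants resolved from the kernel's
    # compile options, not read from a register at runtime.
    num_warps: ttgl.constexpr = ttgl.num_warps()
    num_ctas: ttgl.constexpr = ttgl.num_ctas()
    print("num_warps", num_warps)
    print("num_ctas", num_ctas)

Because the result is a constexpr, it can be used to specialize code paths per cluster size at compile time with no runtime overhead.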

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 25 additions & 0 deletions
@@ -2645,3 +2645,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+// We had a bug where DotOp lowering treated any input where shape[1] == 1 as an
+// outer product and rejected it. This was incorrect in 3D tensors, since
+// the dimension to look at would have been shape[2].
+
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [32, 1, 1], instrShape = [1, 16, 8]}>
+#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>
+#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: batched_dot_3d
+  tt.func public @batched_dot_3d(
+    %arg0: tensor<32x1x32xf16, #dot_operand_a>,
+    %arg1: tensor<32x32x32xf16, #dot_operand_b>
+  ) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x1x32xf32, #mma>
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    %result = tt.dot %arg0, %arg1, %cst, inputPrecision = tf32 :
+      tensor<32x1x32xf16, #dot_operand_a> * tensor<32x32x32xf16, #dot_operand_b> -> tensor<32x1x32xf32, #mma>
+    tt.return
+  }
+}

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 30 additions & 0 deletions
@@ -1,5 +1,17 @@
 // RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv='compute-capability=90 ptx-version=81' --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' | FileCheck %s
 
+module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: @test_cluster_attr
+  // CHECK: nvvm.cluster_dim = array<i32: 4>
+  // CHECK: nvvm.kernel = 1 : ui1
+  // CHECK: nvvm.reqntid = array<i32: 128>
+  tt.func @test_cluster_attr(%lb : index, %A : !tt.ptr<f16>) {
+    tt.return
+  }
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
@@ -74,6 +86,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
 
 // -----
 
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [16, 2], instrShape = [16, 256, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @warp_group_dot_bf16_32_warps
+  tt.func @warp_group_dot_bf16_32_warps(
+    %a: !ttg.memdesc<256x128xbf16, #shared, #smem>,
+    %b: !ttg.memdesc<128x512xbf16, #shared, #smem>,
+    %acc: tensor<256x512xf32, #mma>) {
+    %res = ttng.warp_group_dot %a, %b, %acc {inputPrecision = 0 : i32, isAsync = true} :
+      !ttg.memdesc<256x128xbf16, #shared, #smem> * !ttg.memdesc<128x512xbf16, #shared, #smem> -> tensor<256x512xf32, #mma>
+    // CHECK: nvgpu.wgmma {{.*}} k = 16 : i32, layoutA = 1 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32}
+    tt.return
+  }
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>

test/NVWS/aref-tmem-insertion.mlir

Lines changed: 6 additions & 4 deletions
@@ -558,15 +558,17 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
     // CHECK-NEXT: aref.create
    // CHECK-NEXT: aref.put.enter
     %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
-    scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 {
-      %0 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+    %5 = scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg4 = %token) -> (!ttg.async.token) : i32 {
+      %0 = ttg.local_alloc %arg1 {ttg.partition = array<i32: 0>} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
       %1 = tt.descriptor_load %arg2[%arg3, %arg3] {ttg.partition = array<i32: 2>} : !tt.tensordesc<tensor<64x128xf16, #shared>> -> tensor<64x128xf16, #blocked1>
       %2 = arith.addf %1, %1 {ttg.partition = array<i32: 0>} : tensor<64x128xf16, #blocked1>
       %3 = ttg.local_alloc %2 {ttg.partition = array<i32: 0>} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
       // CHECK: aref.buffer
-      %4 = ttng.tc_gen5_mma %0, %3, %result[%token], %true, %true {ttg.partition = array<i32: 1>} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
-    } {tt.num_stages = 2 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 18 : i32}
+      %4 = ttng.tc_gen5_mma %0, %3, %result[%arg4], %true, %true {ttg.partition = array<i32: 1>} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      scf.yield %4 : !ttg.async.token
+    } {ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 1>], tt.num_stages = 2 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 18 : i32}
     // CHECK: aref.put.exit
+    ttng.tmem_load %result[%5] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
     tt.return
   }

test/NVWS/assign_stage_phase.mlir

Lines changed: 52 additions & 0 deletions
@@ -674,3 +674,55 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
+  // CHECK-LABEL: @for_loop_control_operand_ppg
+  tt.func @for_loop_control_operand_ppg(%lb: i32, %ub: i32, %step: i32, %ptr0: !tt.ptr<i32>) {
+    %true = arith.constant true
+    %arefBuf = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    %aref = nvws.aref.create %arefBuf : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>
+    %_0, %tok = nvws.aref.put.enter %aref : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
+    // CHECK: put.enter
+    // CHECK-NEXT: [[RET:%.*]]:5 = scf.for
+    %tok0 = scf.for %iv0 = %lb to %ub step %step iter_args(%tok1 = %tok) -> (!ttg.async.token) : i32 {
+      // CHECK-NEXT: tt.addptr {{.*}} {ttg.partition = array<i32: 0, 1, 2>}
+      // CHECK-NEXT: tt.load {{.*}} {ttg.partition = array<i32: 0, 1, 2>}
+      // CHECK-NEXT: "lb1"({{.*}}) {ttg.partition = array<i32: 0, 1, 2>}
+      // CHECK-NEXT: "step1"({{.*}}) {ttg.partition = array<i32: 0, 1, 2>}
+      %ptrub = tt.addptr %ptr0, %iv0 {ttg.partition = array<i32: 1, 2>} : !tt.ptr<i32>, i32
+      %ub1 = tt.load %ptrub {ttg.partition = array<i32: 1, 2>} : !tt.ptr<i32>
+      %lb1 = "lb1"(%iv0) {ttg.partition = array<i32: 1, 2>} : (i32) -> i32
+      %step1 = "step1"(%iv0) {ttg.partition = array<i32: 1, 2>} : (i32) -> i32
+      // CHECK-NEXT: [[RET1:%.*]]:3 = scf.for
+      %tok5 = scf.for %iv = %lb1 to %ub1 step %step1 iter_args(%tok2 = %tok1) -> (!ttg.async.token) : i32 {
+        %sA = "load1"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<128x64xf32, #shared, #smem>
+        %sB = "load2"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<64x128xf32, #shared, #smem>
+        %buf = nvws.aref.buffer %aref, %tok2 {ttg.partition = array<i32: 2>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+        ttng.tc_gen5_mma %sA, %sB, %buf, %true, %true {ttg.partition = array<i32: 2>} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x128xf32, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+        scf.yield {ttg.partition = array<i32: 1, 2>} %tok2 : !ttg.async.token
+      } {ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>]}
+      // CHECK: scf.yield
+      // CHECK-NEXT: {ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 0, 2>, array<i32: 2>]}
+      // CHECK-NEXT: nvws.aref.put.exit {{.*}}[[[RET1]]#1]
+      nvws.aref.put.exit %aref, %tok5 [#nvws.async_op<tc5mma>] {ttg.partition = array<i32: 2>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
+      %_1, %token_2 = nvws.aref.get.enter %aref {ttg.partition = array<i32: 1>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
+      nvws.aref.get.exit %aref, %token_2 [#nvws.async_op<none>] {ttg.partition = array<i32: 1>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
+      %buf1, %tok6 = nvws.aref.put.enter %aref {ttg.partition = array<i32: 2>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
+      // CHECK: aref.put.enter
+      // CHECK-NEXT: scf.yield
+      scf.yield {ttg.partition = array<i32: 1, 2>} %tok6 : !ttg.async.token
+      // CHECK-NEXT: {tt.warp_specialize, ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 0, 2>, array<i32: 2>, array<i32: 0, 1>, array<i32: 0, 1>]}
+    } {tt.warp_specialize, ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>]}
+    // CHECK-NEXT: aref.put.exit {{.*}}[[[RET]]#1]
+    nvws.aref.put.exit %aref, %tok0 [#nvws.async_op<tc5mma>] : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
+    %_2, %token_2 = nvws.aref.get.enter %aref : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
+    nvws.aref.get.exit %aref, %token_2 [#nvws.async_op<none>] : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
+    tt.return
+  }
+}
