Skip to content

Commit a12938d

Browse files
Merge OpenAI Triton commit 9b75018 (#5522)
This PR changes the Triton base from 618ec40 to 9b75018 (Nov 5). Pass rate: 95.23%
2 parents 8528cf6 + bccf48a commit a12938d

File tree

23 files changed

+989
-200
lines changed

23 files changed

+989
-200
lines changed

python/test/gluon/test_frontend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3006,6 +3006,12 @@ def print_num_warps():
30063006
print("num_warps", num_warps)
30073007

30083008

3009+
@gluon.jit
3010+
def print_num_ctas():
3011+
num_ctas: ttgl.constexpr = ttgl.num_ctas()
3012+
print("num_ctas", num_ctas)
3013+
3014+
30093015
@filecheck_test
30103016
@gluon.jit
30113017
def test_get_num_warps():
@@ -3030,6 +3036,15 @@ def test_get_num_warps():
30303036
], [1, 2, 8], [24, 24, 24])
30313037

30323038

3039+
@filecheck_test
3040+
@gluon.jit
3041+
def test_num_ctas():
3042+
# CHECK-LABEL: test_num_ctas
3043+
# CHECK: tt.func private @{{.*}}print_num_ctas
3044+
# CHECK-NEXT: arith.constant 1 : i32
3045+
print_num_ctas()
3046+
3047+
30333048
def test_mismatch_shape_and_layout_rank():
30343049

30353050
@gluon.jit

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
full,
5353
gather,
5454
num_warps,
55+
num_ctas,
5556
histogram,
5657
inline_asm_elementwise,
5758
join,

python/triton/experimental/gluon/language/_core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
"static_range",
7575
"tuple",
7676
"tuple_type",
77+
"num_ctas",
7778
]
7879

7980
T = TypeVar("T")
@@ -525,6 +526,14 @@ def num_warps(_semantic=None, _generator=None):
525526
return _semantic.num_warps(_generator)
526527

527528

529+
@builtin
530+
def num_ctas(_semantic=None):
531+
"""
532+
Returns the number of CTAs in the current kernel
533+
"""
534+
return _semantic.num_ctas()
535+
536+
528537
@builtin
529538
def thread_barrier(_semantic=None):
530539
"""

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,9 @@ def warp_specialize(self, functions_and_args, worker_num_warps: Sequence[int], w
551551
return
552552
return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
553553

554+
def num_ctas(self):
555+
return ttgl.constexpr(self.builder.options.num_ctas)
556+
554557
def num_warps(self, generator):
555558
if generator.caller_context is not None:
556559
assert isinstance(generator.caller_context, GluonCallerContext)

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2645,3 +2645,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
26452645
tt.return
26462646
}
26472647
}
2648+
2649+
// -----
2650+
2651+
// We had a bug where DotOp lowering treated any input where shape[1] == 1 as an
2652+
// outer product and rejected it. This was incorrect in 3D tensors, since
2653+
// the dimension to look at would have been shape[2].
2654+
2655+
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [32, 1, 1], instrShape = [1, 16, 8]}>
2656+
#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>
2657+
#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>
2658+
2659+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
2660+
// CHECK-LABEL: batched_dot_3d
2661+
tt.func public @batched_dot_3d(
2662+
%arg0: tensor<32x1x32xf16, #dot_operand_a>,
2663+
%arg1: tensor<32x32x32xf16, #dot_operand_b>
2664+
) {
2665+
%cst = arith.constant dense<0.000000e+00> : tensor<32x1x32xf32, #mma>
2666+
// CHECK: llvm.inline_asm
2667+
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
2668+
%result = tt.dot %arg0, %arg1, %cst, inputPrecision = tf32 :
2669+
tensor<32x1x32xf16, #dot_operand_a> * tensor<32x32x32xf16, #dot_operand_b> -> tensor<32x1x32xf32, #mma>
2670+
tt.return
2671+
}
2672+
}

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
7474

7575
// -----
7676

77+
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [16, 2], instrShape = [16, 256, 16]}>
78+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
79+
#smem = #ttg.shared_memory
80+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
81+
// CHECK-LABEL: @warp_group_dot_bf16_32_warps
82+
tt.func @warp_group_dot_bf16_32_warps(
83+
%a: !ttg.memdesc<256x128xbf16, #shared, #smem>,
84+
%b: !ttg.memdesc<128x512xbf16, #shared, #smem>,
85+
%acc: tensor<256x512xf32, #mma>) {
86+
%res = ttng.warp_group_dot %a, %b, %acc {inputPrecision = 0 : i32, isAsync = true} :
87+
!ttg.memdesc<256x128xbf16, #shared, #smem> * !ttg.memdesc<128x512xbf16, #shared, #smem> -> tensor<256x512xf32, #mma>
88+
// CHECK: nvgpu.wgmma {{.*}} k = 16 : i32, layoutA = 1 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32}
89+
tt.return
90+
}
91+
}
92+
93+
// -----
94+
7795
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
7896
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
7997
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>

test/NVWS/aref-tmem-insertion.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -558,15 +558,17 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
558558
// CHECK-NEXT: aref.create
559559
// CHECK-NEXT: aref.put.enter
560560
%result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
561-
scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 {
562-
%0 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
561+
%5 = scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg4 = %token) -> (!ttg.async.token) : i32 {
562+
%0 = ttg.local_alloc %arg1 {ttg.partition = array<i32: 0>} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
563563
%1 = tt.descriptor_load %arg2[%arg3, %arg3] {ttg.partition = array<i32: 2>} : !tt.tensordesc<tensor<64x128xf16, #shared>> -> tensor<64x128xf16, #blocked1>
564564
%2 = arith.addf %1, %1 {ttg.partition = array<i32: 0>} : tensor<64x128xf16, #blocked1>
565565
%3 = ttg.local_alloc %2 {ttg.partition = array<i32: 0>} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
566566
// CHECK: aref.buffer
567-
%4 = ttng.tc_gen5_mma %0, %3, %result[%token], %true, %true {ttg.partition = array<i32: 1>} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
568-
} {tt.num_stages = 2 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 18 : i32}
567+
%4 = ttng.tc_gen5_mma %0, %3, %result[%arg4], %true, %true {ttg.partition = array<i32: 1>} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
568+
scf.yield %4 : !ttg.async.token
569+
} {ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 1>], tt.num_stages = 2 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 18 : i32}
569570
// CHECK: aref.put.exit
571+
ttng.tmem_load %result[%5] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
570572
tt.return
571573
}
572574

test/NVWS/assign_stage_phase.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,3 +674,55 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
674674
tt.return
675675
}
676676
}
677+
678+
// -----
679+
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
680+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
681+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
682+
#smem = #ttg.shared_memory
683+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
684+
module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
685+
// CHECK-LABEL: @for_loop_control_operand_ppg
686+
tt.func @for_loop_control_operand_ppg(%lb: i32, %ub: i32, %step: i32, %ptr0: !tt.ptr<i32>) {
687+
%true = arith.constant true
688+
%arefBuf = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>
689+
%aref = nvws.aref.create %arefBuf : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>
690+
%_0, %tok = nvws.aref.put.enter %aref : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
691+
// CHECK: put.enter
692+
// CHECK-NEXT: [[RET:%.*]]:5 = scf.for
693+
%tok0 = scf.for %iv0 = %lb to %ub step %step iter_args(%tok1 = %tok) -> (!ttg.async.token) : i32 {
694+
// CHECK-NEXT: tt.addptr {{.*}} {ttg.partition = array<i32: 0, 1, 2>}
695+
// CHECK-NEXT: tt.load {{.*}} {ttg.partition = array<i32: 0, 1, 2>}
696+
// CHECK-NEXT: "lb1"({{.*}}) {ttg.partition = array<i32: 0, 1, 2>}
697+
// CHECK-NEXT: "step1"({{.*}}) {ttg.partition = array<i32: 0, 1, 2>}
698+
%ptrub = tt.addptr %ptr0, %iv0 {ttg.partition = array<i32: 1, 2>} : !tt.ptr<i32>, i32
699+
%ub1 = tt.load %ptrub {ttg.partition = array<i32: 1, 2>} : !tt.ptr<i32>
700+
%lb1 = "lb1"(%iv0) {ttg.partition = array<i32: 1, 2>} : (i32) -> i32
701+
%step1 = "step1"(%iv0) {ttg.partition = array<i32: 1, 2>} : (i32) -> i32
702+
// CHECK-NEXT: [[RET1:%.*]]:3 = scf.for
703+
%tok5 = scf.for %iv = %lb1 to %ub1 step %step1 iter_args(%tok2 = %tok1) -> (!ttg.async.token) : i32 {
704+
%sA = "load1"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<128x64xf32, #shared, #smem>
705+
%sB = "load2"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<64x128xf32, #shared, #smem>
706+
%buf = nvws.aref.buffer %aref, %tok2 {ttg.partition = array<i32: 2>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
707+
ttng.tc_gen5_mma %sA, %sB, %buf, %true, %true {ttg.partition = array<i32: 2>} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x128xf32, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
708+
scf.yield {ttg.partition = array<i32: 1, 2>} %tok2 : !ttg.async.token
709+
} {ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>]}
710+
// CHECK: scf.yield
711+
// CHECK-NEXT: {ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 0, 2>, array<i32: 2>]}
712+
// CHECK-NEXT: nvws.aref.put.exit {{.*}}[[[RET1]]#1]
713+
nvws.aref.put.exit %aref, %tok5 [#nvws.async_op<tc5mma>] {ttg.partition = array<i32: 2>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
714+
%_1, %token_2 = nvws.aref.get.enter %aref {ttg.partition = array<i32: 1>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
715+
nvws.aref.get.exit %aref, %token_2 [#nvws.async_op<none>] {ttg.partition = array<i32: 1>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
716+
%buf1, %tok6 = nvws.aref.put.enter %aref {ttg.partition = array<i32: 2>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
717+
// CHECK: aref.put.enter
718+
// CHECK-NEXT: scf.yield
719+
scf.yield {ttg.partition = array<i32: 1, 2>} %tok6 : !ttg.async.token
720+
// CHECK-NEXT: {tt.warp_specialize, ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 0, 2>, array<i32: 2>, array<i32: 0, 1>, array<i32: 0, 1>]}
721+
} {tt.warp_specialize, ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>]}
722+
// CHECK-NEXT: aref.put.exit {{.*}}[[[RET]]#1]
723+
nvws.aref.put.exit %aref, %tok0 [#nvws.async_op<tc5mma>] : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
724+
%_2, %token_2 = nvws.aref.get.enter %aref : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
725+
nvws.aref.get.exit %aref, %token_2 [#nvws.async_op<none>] : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
726+
tt.return
727+
}
728+
}

0 commit comments

Comments
 (0)