Skip to content

Commit 0be6469

Browse files
authored
[GLUON] Implement num_ctas (#8602)
1 parent 4c6349d commit 0be6469

File tree

4 files changed

+28
-0
lines changed

4 files changed

+28
-0
lines changed

python/test/gluon/test_frontend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3006,6 +3006,12 @@ def print_num_warps():
30063006
print("num_warps", num_warps)
30073007

30083008

3009+
@gluon.jit
3010+
def print_num_ctas():
3011+
num_ctas: ttgl.constexpr = ttgl.num_ctas()
3012+
print("num_ctas", num_ctas)
3013+
3014+
30093015
@filecheck_test
30103016
@gluon.jit
30113017
def test_get_num_warps():
@@ -3030,6 +3036,15 @@ def test_get_num_warps():
30303036
], [1, 2, 8], [24, 24, 24])
30313037

30323038

3039+
@filecheck_test
3040+
@gluon.jit
3041+
def test_num_ctas():
3042+
# CHECK-LABEL: test_num_ctas
3043+
# CHECK: tt.func private @{{.*}}print_num_ctas
3044+
# CHECK-NEXT: arith.constant 1 : i32
3045+
print_num_ctas()
3046+
3047+
30333048
def test_mismatch_shape_and_layout_rank():
30343049

30353050
@gluon.jit

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
full,
5353
gather,
5454
num_warps,
55+
num_ctas,
5556
histogram,
5657
inline_asm_elementwise,
5758
join,

python/triton/experimental/gluon/language/_core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
"static_range",
7575
"tuple",
7676
"tuple_type",
77+
"num_ctas",
7778
]
7879

7980
T = TypeVar("T")
@@ -525,6 +526,14 @@ def num_warps(_semantic=None, _generator=None):
525526
return _semantic.num_warps(_generator)
526527

527528

529+
@builtin
530+
def num_ctas(_semantic=None):
531+
"""
532+
Returns the number of CTAs in the current kernel
533+
"""
534+
return _semantic.num_ctas()
535+
536+
528537
@builtin
529538
def thread_barrier(_semantic=None):
530539
"""

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,9 @@ def warp_specialize(self, functions_and_args, worker_num_warps: Sequence[int], w
551551
return
552552
return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
553553

554+
def num_ctas(self):
555+
return ttgl.constexpr(self.builder.options.num_ctas)
556+
554557
def num_warps(self, generator):
555558
if generator.caller_context is not None:
556559
assert isinstance(generator.caller_context, GluonCallerContext)

0 commit comments

Comments
 (0)