Skip to content

Commit f902558

Browse files
authored
[Gluon] Add ttgl.get_num_warps metafunction (#8133)
This adds a metafunction (exposed as `ttgl.num_warps()`) that returns the current contextual number of warps as a constexpr. When called from a regular function it returns the value passed as `num_warps=`; inside a warp-specialized region it returns the number of warps in the current partition. cc @aeng-openai
1 parent 625c8cb commit f902558

File tree

4 files changed

+41
-0
lines changed

4 files changed

+41
-0
lines changed

python/test/gluon/test_frontend.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,3 +2375,29 @@ def test_layout_zeros():
23752375
# CHECK: #blocked = #ttg.blocked
23762376
# CHECK: arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked>
23772377
ttgl.zeros([128], ttgl.float32, layout=ttgl.BlockedLayout([1], [32], [4], [0]))
2378+
2379+
2380+
@gluon.jit
def print_num_warps():
    # Helper for the num_warps test below: query the contextual warp count
    # and print it, so the value shows up in the generated IR.
    # The constexpr annotation is kept so the queried value stays a
    # compile-time constant.
    warp_count: ttgl.constexpr = ttgl.num_warps()
    print("num_warps", warp_count)
2384+
2385+
2386+
@filecheck_test
@gluon.jit
def test_get_num_warps():
    # Verifies that ttgl.num_warps() resolves to a different constant in each
    # specialization of print_num_warps: one for the plain call, the default
    # partition, and one per worker partition (1, 2 and 8 warps).
    #
    # Each expected constant is matched on the line immediately following its
    # function header. The "next line" directives need their trailing colon;
    # without it FileCheck silently ignores them and the constants are never
    # verified.
    #
    # NOTE(review): the expected value 4 for the non-worker specialization
    # presumably comes from the test harness's default num_warps; confirm.

    # CHECK-LABEL: test_get_num_warps
    # CHECK: tt.func private @{{.*}}print_num_warps
    # CHECK-NEXT: arith.constant 4 : i32

    # CHECK: tt.func private @{{.*}}print_num_warps{{.*}}NW1
    # CHECK-NEXT: arith.constant 1 : i32

    # CHECK: tt.func private @{{.*}}print_num_warps{{.*}}NW2
    # CHECK-NEXT: arith.constant 2 : i32

    # CHECK: tt.func private @{{.*}}print_num_warps{{.*}}NW8
    # CHECK-NEXT: arith.constant 8 : i32
    print_num_warps()
    # Three worker partitions with 1, 2 and 8 warps respectively; the final
    # list is presumably the per-worker register budget (worker_num_regs).
    ttgl.warp_specialize((), print_num_warps, (), [print_num_warps, print_num_warps, print_num_warps], [1, 2, 8],
                         [24, 24, 24])

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
expand_dims,
4848
full,
4949
gather,
50+
num_warps,
5051
histogram,
5152
inline_asm_elementwise,
5253
join,

python/triton/experimental/gluon/language/_core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,14 @@ def warp_specialize(default_args, default_partition, worker_args, worker_partiti
501501
worker_num_regs, _generator)
502502

503503

504+
@builtin
def num_warps(_semantic=None, _generator=None):
    """
    Query how many warps execute the code at the point of the call.

    The lookup is delegated to the semantic layer, which is handed the active
    code generator so that it can account for the calling context (for
    example, a warp-specialized region).

    :return: whatever ``_semantic.num_warps`` produces for ``_generator``.
    """
    warp_count = _semantic.num_warps(_generator)
    return warp_count
510+
511+
504512
@builtin
505513
def thread_barrier(_semantic=None):
506514
"""

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,9 @@ def warp_specialize(self, default_args, default_partition, worker_args, worker_p
427427
if default_results is None:
428428
return
429429
return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))
430+
431+
def num_warps(self, generator):
    """
    Resolve the number of warps for the code ``generator`` is emitting.

    When the generator carries a caller context, that context's
    ``num_warps`` is used; it is asserted to be a ``GluonCallerContext``.
    Without a caller context, the value falls back to the builder options'
    ``num_warps``.

    :param generator: code generator exposing a ``caller_context`` attribute.
    :return: the warp count wrapped in ``ttgl.constexpr``.
    """
    caller_ctx = generator.caller_context
    if caller_ctx is None:
        # No contextual override: use the kernel-wide launch option.
        return ttgl.constexpr(self.builder.options.num_warps)
    assert isinstance(caller_ctx, GluonCallerContext)
    return ttgl.constexpr(caller_ctx.num_warps)

0 commit comments

Comments
 (0)