Skip to content

Commit b37bd6b

Browse files
authored
[Frontend][Gluon] Pass contextual num_warps down call graph (#7529)
Warp specialization changes the number of contextual warps for certain sections of code, and this information must also be propagated down the call graph. This PR adds an optional `caller_context` that a caller can set and that is then propagated to its callees. `ttgl.warp_specialize` uses it to pass each worker partition's number of warps down to the functions that partition calls.
1 parent 98b1409 commit b37bd6b

File tree

3 files changed

+66
-6
lines changed

3 files changed

+66
-6
lines changed

python/test/gluon/test_frontend.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,47 @@ def test_warp_specialize():
419419
ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, (pair, c, e), [warp_specialize_worker1], [4], [48])
420420

421421

422+
# Shared body for the warp-specialize partitions below: anchors a 128-element
# arange whose blocked layout uses `num_warps` warps, so the layout that shows up
# in the emitted IR reflects the warp count each caller passed in.
@gluon.jit
def ws_body(num_warps: ttgl.constexpr):
    layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [num_warps], [0])
    anchor(ttgl.arange(0, 128, layout=layout))
425+
426+
427+
# Passed as the default partition in test_num_warps_caller_context; instantiates
# the shared body with a 4-warp layout.
@gluon.jit
def ws_test_default():
    ws_body(num_warps=4)
430+
431+
432+
# First worker partition in test_num_warps_caller_context (launched with 2 warps);
# instantiates the shared body with a matching 2-warp layout.
@gluon.jit
def ws_test_worker0():
    ws_body(num_warps=2)
435+
436+
437+
# Second worker partition in test_num_warps_caller_context (launched with 1 warp);
# instantiates the shared body with a matching 1-warp layout.
@gluon.jit
def ws_test_worker1():
    ws_body(num_warps=1)
440+
441+
442+
# Verifies that the warp count of each warp_specialize worker partition is carried
# down the call graph: functions reached from a worker get an `_NW<n>` name suffix
# and a matching "ttg.num-warps" attribute, while functions reached from the
# default partition get neither. The FileCheck directives inside the body are the
# expected output and must not be reworded.
@filecheck_test
@gluon.jit
def test_num_warps_caller_context():
    # CHECK-DAG: [[BLOCKED_NW4:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
    # CHECK-DAG: [[BLOCKED_NW2:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
    # CHECK-DAG: [[BLOCKED_NW1:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>

    # CHECK: func private @{{.*}}ws_test_default{{.*}}() attributes {noinline = false}
    # CHECK: func private @{{.*}}ws_body{{.*}}() attributes {noinline = false}
    # CHECK: func private @{{.*}}anchor{{.*}}(%arg0: tensor<128xi32, [[BLOCKED_NW4]]>) attributes {noinline = false}

    # CHECK: func private @{{.*}}ws_test_worker0{{.*}}_NW2() attributes {noinline = false, "ttg.num-warps" = 2 : i32}
    # CHECK: func private @{{.*}}ws_body{{.*}}_NW2"() attributes {noinline = false, "ttg.num-warps" = 2 : i32}
    # CHECK: func private @{{.*}}anchor{{.*}}_NW2(%arg0: tensor<128xi32, [[BLOCKED_NW2]]>) attributes {noinline = false, "ttg.num-warps" = 2 : i32}

    # CHECK: func private @{{.*}}ws_test_worker1{{.*}}_NW1() attributes {noinline = false, "ttg.num-warps" = 1 : i32}
    # CHECK: func private @{{.*}}ws_body{{.*}}_NW1"() attributes {noinline = false, "ttg.num-warps" = 1 : i32}
    # CHECK: func private @{{.*}}anchor{{.*}}_NW1(%arg0: tensor<128xi32, [[BLOCKED_NW1]]>) attributes {noinline = false, "ttg.num-warps" = 1 : i32}
    ttgl.warp_specialize(
        (),  # default partition arguments
        ws_test_default,  # default partition
        (),  # worker arguments
        [ws_test_worker0, ws_test_worker1],  # worker partitions
        [2, 1],  # warps per worker partition
        [80, 80],  # presumably registers per worker partition; confirm against ttgl.warp_specialize
    )
461+
462+
422463
@gluon.jit
423464
def mbarrier_kernel():
424465
bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())

python/triton/compiler/code_generator.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def check_identifier_legality(name, type):
2828
return name
2929

3030

31-
def mangle_fn(name, arg_tys, constants):
31+
def mangle_fn(name, arg_tys, constants, caller_context):
3232
# doesn't mangle ret type, which must be a function of arg tys
3333
mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
3434
mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
@@ -37,6 +37,8 @@ def mangle_fn(name, arg_tys, constants):
3737
# [ and ] are not allowed in LLVM identifiers
3838
mangled_constants = mangled_constants.replace('[', '_').replace(']', '_')
3939
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
40+
if caller_context is not None:
41+
ret += caller_context.mangle()
4042
return ret
4143

4244

@@ -293,7 +295,7 @@ class CodeGenerator(ast.NodeVisitor):
293295

294296
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map,
295297
module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
296-
file_name: Optional[str] = None, begin_line=0):
298+
caller_context=None, file_name: Optional[str] = None, begin_line=0):
297299
self.context = context
298300
if jit_fn.is_gluon():
299301
from triton.experimental.gluon.language._semantic import GluonSemantic
@@ -339,6 +341,7 @@ def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunctio
339341
self.is_kernel = is_kernel
340342
self.cur_node = None
341343
self.noinline = noinline
344+
self.caller_context = caller_context
342345
self.scf_stack = []
343346
self.ret_type = None
344347
# SSA-construction
@@ -570,6 +573,8 @@ def visit_FunctionDef(self, node):
570573
self.module.push_back(self.fn)
571574
entry = self.fn.add_entry_block()
572575
arg_values = self.prototype.deserialize(self.fn)
576+
if self.caller_context is not None:
577+
self.caller_context.initialize_callee(self.fn, self.builder)
573578
# bind arguments to symbols
574579
for arg_name, arg_value in zip(arg_names, arg_values):
575580
self.set_value(arg_name, arg_value)
@@ -1190,7 +1195,7 @@ def visit_Assert(self, node) -> Any:
11901195
msg = self.visit(node.msg) if node.msg is not None else ""
11911196
return language.core.device_assert(test, msg, _semantic=self.semantic)
11921197

1193-
def call_JitFunction(self, fn: JITFunction, args, kwargs):
1198+
def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
11941199
args = inspect.getcallargs(fn.fn, *args, **kwargs)
11951200
args = [args[name] for name in fn.arg_names]
11961201
for i, arg in enumerate(args):
@@ -1201,7 +1206,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
12011206
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
12021207
args_val = [get_iterable_path(args, path) for path in args_path]
12031208
# mangle
1204-
fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst)
1209+
caller_context = caller_context or self.caller_context
1210+
fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst, caller_context)
12051211
# generate function def if necessary
12061212
if not self.module.has_function(fn_name):
12071213
# If the callee is not set, we use the same debug setting as the caller
@@ -1216,7 +1222,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
12161222
function_name=fn_name, function_types=self.function_ret_types,
12171223
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
12181224
options=self.builder.options, codegen_fns=self.builder.codegen_fns,
1219-
module_map=self.builder.module_map)
1225+
module_map=self.builder.module_map, caller_context=caller_context)
12201226
try:
12211227
generator.visit(fn.parse())
12221228
except Exception as e:

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@ def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
1414
raise category(msg_fn())
1515

1616

17+
class GluonCallerContext:
    """Context that a caller hands to the functions it calls during code generation.

    It carries a single value, the number of warps, and knows how to encode
    that value into a mangled name and onto a generated function.
    """

    def __init__(self, num_warps: int):
        """Store the warp count that callees should be generated for."""
        self.num_warps = num_warps

    def mangle(self):
        """Return the ``_NW<num_warps>`` suffix that encodes this context in a name."""
        return "_NW{}".format(self.num_warps)

    def initialize_callee(self, fn, builder):
        """Set ``ttg.num-warps`` on ``fn`` to an int32 attribute created by ``builder``."""
        num_warps_attr = builder.get_int32_attr(self.num_warps)
        fn.set_attr("ttg.num-warps", num_warps_attr)
27+
28+
1729
class GluonSemantic(TritonSemantic[TensorTy]):
1830
tensor = ttgl.tensor
1931
lang = ttgl
@@ -319,10 +331,11 @@ def warp_specialize(self, default_args, default_partition, worker_args, worker_p
319331
partitions_op = builder.create_warp_specialize_partitions(num_partitions)
320332
arg_types = [arg.get_type() for arg in mlir_args]
321333
for i in range(num_partitions):
334+
caller_context = GluonCallerContext(num_warps=worker_num_warps[i])
322335
block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
323336
block_args = [block.get_argument(j) for j in range(len(mlir_args))]
324337
block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
325-
generator.call_JitFunction(worker_partitions[i], block_args, kwargs={})
338+
generator.call_JitFunction(worker_partitions[i], block_args, kwargs={}, caller_context=caller_context)
326339
builder.create_warp_return()
327340

328341
builder.set_insertion_point_after(ws_op.get_operation())

0 commit comments

Comments
 (0)