Skip to content

Commit ba5ac26

Browse files
authored
Improve detection of loop carries in triton frontend (#7200)
Before this change, loop carries weren't correctly detected when `@builtin` or `@core.extern` function modified their arguments (which is a particular issue for member functions). This improves the detection of loop carries in for and while loops to handle these cases. # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent d4432f6 commit ba5ac26

File tree

2 files changed

+87
-70
lines changed

2 files changed

+87
-70
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,44 @@ def test_aggregate_initializers():
103103
anchor(value)
104104

105105

106+
@filecheck_test
107+
@triton.jit
108+
def test_aggregate_modification_in_for_loop():
109+
# CHECK-LABEL: test_aggregate_modification_in_for_loop
110+
value = TypeWithBuiltinInitializer()
111+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
112+
for i in range(0, 2):
113+
# CHECK: [[RET:%.*]] = scf.for
114+
# CHECK-SAME: iter_args([[ITER:%.*]] = [[RANGE]])
115+
value.modify(tl.arange(4, 8))
116+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32}
117+
# CHECK: yield [[RANGE]]
118+
119+
anchor(value)
120+
# CHECK: call @{{.*}}anchor{{.*}}([[RET]])
121+
122+
123+
@filecheck_test
124+
@triton.jit
125+
def test_aggregate_modification_in_while_loop():
126+
# CHECK-LABEL: test_aggregate_modification_in_while_loop
127+
value = TypeWithBuiltinInitializer()
128+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
129+
i = 0
130+
# CHECK: [[C0:%.*]] = arith.constant 0 :
131+
while i < 1:
132+
# CHECK: [[RET:%.*]]:2 = scf.while ([[ITER:%.*]] = [[RANGE]], [[IV:%.*]] = [[C0]])
133+
# CHECK: do
134+
i = 1
135+
# CHECK: [[C1:%.*]] = arith.constant 1 :
136+
value.modify(tl.arange(4, 8))
137+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32}
138+
# CHECK: yield [[RANGE]], [[C1]]
139+
140+
anchor(value)
141+
# CHECK: call @{{.*}}anchor{{.*}}([[RET]]#0)
142+
143+
106144
@triton.jit
107145
def forward(arg):
108146
return arg

python/triton/compiler/code_generator.py

Lines changed: 49 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,47 @@ def _set_insertion_point_and_loc(self, ip, loc):
435435
self.builder.restore_insertion_point(ip)
436436
self.builder.set_loc(loc)
437437

438+
def _find_carries(self, node, liveins):
439+
# We must extract the handles before the value is editted in the loop
440+
livehandles = {name: flatten_values_to_ir([v]) for name, v in liveins.items() if _is_triton_value(v)}
441+
# create loop body block
442+
block = self.builder.create_block()
443+
self.builder.set_insertion_point_to_start(block)
444+
# dry visit loop body
445+
self.scf_stack.append(node)
446+
self.visit_compound_statement(node.body)
447+
self.scf_stack.pop()
448+
block.erase()
449+
450+
# If a variable (name) has changed value within the loop, then it's
451+
# a loop-carried variable. (The new and old value must be of the
452+
# same type)
453+
init_tys = []
454+
init_handles = []
455+
names = []
456+
457+
for name, live_val in liveins.items():
458+
if _is_triton_value(live_val):
459+
loop_val = self.lscope[name]
460+
assert type(live_val) is type(loop_val), f'Loop carried variable {name} changed type'
461+
462+
live_handles = livehandles[name]
463+
loop_handles = flatten_values_to_ir([loop_val])
464+
if live_handles != loop_handles:
465+
self._verify_loop_carried_variable(name, loop_val, live_val)
466+
467+
names.append(name)
468+
init_tys.append(live_val.type)
469+
init_handles.extend(live_handles)
470+
else:
471+
assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value'
472+
473+
# reset local scope to not pick up local defs from the dry run.
474+
self.lscope = liveins.copy()
475+
self.local_defs = {}
476+
477+
return names, init_handles, init_tys
478+
438479
#
439480
# AST visitor
440481
#
@@ -918,8 +959,8 @@ def visit_UnaryOp(self, node):
918959
}
919960

920961
def _verify_loop_carried_variable(self, name, loop_val, live_val):
921-
assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop'
922-
assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop'
962+
assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
963+
assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
923964
assert type(loop_val) is type(live_val), f'Loop carried variable {name} changed type'
924965
assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
925966
f'Loop-carried variable {name} has initial type {live_val.type} '\
@@ -931,33 +972,9 @@ def visit_While(self, node):
931972
liveins, insert_block = sr
932973
ip, last_loc = self._get_insertion_point_and_loc()
933974

934-
# loop body (the after region)
935-
# loop_block = self.builder.create_block()
936-
dummy = self.builder.create_block()
937-
self.builder.set_insertion_point_to_start(dummy)
938-
self.scf_stack.append(node)
939-
self.visit_compound_statement(node.body)
940-
self.scf_stack.pop()
941-
loop_defs = self.local_defs
942-
dummy.erase()
943-
944-
# collect loop-carried values
945-
names = []
946-
init_args = []
947-
for name in loop_defs:
948-
if name in liveins:
949-
# We should not def new constexpr
950-
loop_val = loop_defs[name]
951-
live_val = liveins[name]
952-
self._verify_loop_carried_variable(name, loop_val, live_val)
953-
954-
# these are loop-carried values
955-
names.append(name)
956-
init_args.append(live_val)
975+
names, init_handles, init_fe_tys = self._find_carries(node, liveins)
957976

958-
init_handles = flatten_values_to_ir(init_args)
959977
init_tys = [h.get_type() for h in init_handles]
960-
init_fe_tys = [a.type for a in init_args]
961978
self._set_insertion_point_and_loc(ip, last_loc)
962979
while_op = self.builder.create_while_op(init_tys, init_handles)
963980
# merge the condition region
@@ -985,13 +1002,9 @@ def visit_While(self, node):
9851002
self.scf_stack.append(node)
9861003
self.visit_compound_statement(node.body)
9871004
self.scf_stack.pop()
988-
loop_defs = self.local_defs
989-
yields = []
990-
for name in loop_defs:
991-
if name in liveins:
992-
loop_defs[name]._flatten_ir(yields)
9931005

994-
self.builder.create_yield_op(yields)
1006+
yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
1007+
self.builder.create_yield_op(yield_handles)
9951008

9961009
# WhileOp defines new values, update the symbol table (lscope, local_defs)
9971010
result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
@@ -1097,34 +1110,10 @@ def visit_For(self, node):
10971110
liveins, insert_block = sr
10981111
ip, last_loc = self._get_insertion_point_and_loc()
10991112

1100-
# create loop body block
1101-
block = self.builder.create_block()
1102-
self.builder.set_insertion_point_to_start(block)
1103-
# dry visit loop body
1104-
self.scf_stack.append(node)
1105-
self.visit_compound_statement(node.body)
1106-
self.scf_stack.pop()
1107-
block.erase()
1108-
1109-
# If a variable (name) is defined in both its parent & itself, then it's
1110-
# a loop-carried variable. (They must be of the same type)
1111-
init_args = []
1112-
yields = []
1113-
names = []
1114-
for name in self.local_defs:
1115-
if name in liveins:
1116-
loop_val = self.local_defs[name]
1117-
live_val = liveins[name]
1118-
self._verify_loop_carried_variable(name, loop_val, live_val)
1119-
1120-
names.append(name)
1121-
init_args.append(live_val)
1122-
yields.append(loop_val)
1113+
names, init_handles, init_tys = self._find_carries(node, liveins)
11231114

11241115
# create ForOp
11251116
self._set_insertion_point_and_loc(ip, last_loc)
1126-
init_handles = flatten_values_to_ir(init_args)
1127-
init_tys = [v.type for v in init_args]
11281117
for_op = self.builder.create_for_op(lb, ub, step, init_handles)
11291118
if _unwrap_if_constexpr(num_stages) is not None:
11301119
for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
@@ -1140,26 +1129,16 @@ def visit_For(self, node):
11401129
self.scf_stack.append(node)
11411130
for_op_body = for_op.get_body(0)
11421131
self.builder.set_insertion_point_to_start(for_op_body)
1143-
# reset local scope to not pick up local defs from the previous dry run.
1144-
self.lscope = liveins.copy()
1145-
self.local_defs = {}
11461132
block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
11471133
block_args = unflatten_ir_values(block_handles, init_tys)
11481134
for name, val in zip(names, block_args):
11491135
self.set_value(name, val)
11501136
self.visit_compound_statement(node.body)
11511137
self.scf_stack.pop()
1152-
yields = []
1153-
for name in self.local_defs:
1154-
if name in liveins:
1155-
local = self.local_defs[name]
1156-
if isinstance(local, constexpr):
1157-
local = self.semantic.to_tensor(local)
1158-
yields.append(local)
1138+
yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
11591139

11601140
# create YieldOp
1161-
if len(yields) > 0:
1162-
yield_handles = flatten_values_to_ir(yields)
1141+
if len(yield_handles) > 0:
11631142
self.builder.create_yield_op(yield_handles)
11641143
for_op_region = for_op_body.get_parent()
11651144
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"

0 commit comments

Comments
 (0)