Skip to content

Commit 272188c

Browse files
authored
[Frontend] Fix detection of liveouts in conditionals (#7318)
cc @Anstow
1 parent ca37374 commit 272188c

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ def test_constexpr_generator():
374374
generator(lhs)
375375

376376

377+
@tl.constexpr_function
377378
def Box(T):
378379

379380
@tl.core._aggregate
@@ -401,3 +402,23 @@ def kernel():
401402
anchor(value)
402403

403404
run_filecheck_test(kernel)
405+
406+
407+
@filecheck_test
408+
@triton.jit
409+
def test_modify_if_livein():
410+
# CHECK-LABEL: test_modify_if_livein
411+
none_livein = None # noqa: F841
412+
413+
# CHECK: [[LOOP_OUT:%.*]] = scf.for {{.*}} iter_args([[BOX:%.*]] = %true)
414+
# CHECK: [[LIVEOUT:%.*]] = scf.if [[BOX]]
415+
# CHECK: yield %false
416+
# CHECK: else
417+
# CHECK: yield [[BOX]]
418+
# CHECK: yield [[LIVEOUT]]
419+
# CHECK: call @{{.*}}anchor{{.*}}([[LOOP_OUT]])
420+
box = Box(tl.tensor)(tl.core.to_tensor(True))
421+
for i in range(10):
422+
if box.value:
423+
box.value = False
424+
anchor(box.value)

python/triton/compiler/code_generator.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -719,35 +719,40 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block):
719719
self.visit_compound_statement(node.body)
720720
then_block = self.builder.get_insertion_block()
721721
then_defs = self.local_defs.copy()
722+
then_vals = self.lscope.copy()
722723
# else block
723724
else_defs = {}
725+
else_vals = liveins.copy()
724726
if node.orelse:
725727
self.builder.set_insertion_point_to_start(else_block)
726728
self.lscope = liveins.copy()
727729
self.local_defs = {}
728730
self.visit_compound_statement(node.orelse)
729731
else_defs = self.local_defs.copy()
730732
else_block = self.builder.get_insertion_block()
733+
else_vals = self.lscope.copy()
731734

732735
# update block arguments
733736
names = []
734737
# variables in livein whose value is updated in `if`
735-
for name in liveins:
738+
for name, value in liveins.items():
739+
# livein variable changed value in either then or else
740+
if not _is_triton_value(value):
741+
continue
742+
then_handles = flatten_values_to_ir([then_vals[name]])
743+
else_handles = flatten_values_to_ir([else_vals[name]])
744+
if then_handles == else_handles:
745+
continue
746+
names.append(name)
747+
then_defs[name] = then_vals[name]
748+
else_defs[name] = else_vals[name]
736749
# check type
737750
for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
738-
if name in defs:
739-
type_equal = type(defs[name]) == type(liveins[name]) # noqa: E721
740-
assert type_equal and defs[name].type == liveins[name].type, \
741-
f'initial value for `{name}` is of type {liveins[name]}, '\
742-
f'but the {block_name} block redefines it as {defs[name]}'
743-
if name in then_defs or name in else_defs:
744-
names.append(name)
745-
# variable defined in then but not in else
746-
if name in then_defs and name not in else_defs:
747-
else_defs[name] = liveins[name]
748-
# variable defined in else but not in then
749-
if name in else_defs and name not in then_defs:
750-
then_defs[name] = liveins[name]
751+
type_equal = type(defs[name]) == type(value) # noqa: E721
752+
assert type_equal and defs[name].type == value.type, \
753+
f'initial value for `{name}` is of type {value}, '\
754+
f'but the {block_name} block redefines it as {defs[name]}'
755+
751756
# variables that are both in then and else but not in liveins
752757
# TODO: could probably be cleaned up
753758
for name in sorted(then_defs.keys() & else_defs.keys()):

0 commit comments

Comments
 (0)