Skip to content

Commit e090c31

Browse files
authored
[FRONTEND] Remove incorrect and misleading error message from the return-in-while check (#7551)
We actually don't check `return`s in functions being called within `while` or `for`. ``` def _visit_function(self, fn) -> bool: # No need to check within the function as it won't cause an early return. # If the function itself has unstructured control flow we may not be able to inline it causing poor performance, # we should check for this and emit a warning. return False ```
1 parent d6b0238 commit e090c31

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import triton
22
import triton.language as tl
33
from triton._filecheck import filecheck_test, run_filecheck_test
4+
from triton.compiler.errors import CompilationError
5+
import pytest
46

57
# ===-----------------------------------------------------------------------===#
68
# Unit Tests
@@ -528,3 +530,36 @@ def test_specialized_recursion():
528530

529531
# CHECK: func {{.*}}recursive_reduce__i32S4S
530532
# CHECK-COUNT-2: call {{.*}}recursive_reduce__i32S2S
533+
534+
535+
@triton.jit
536+
def trivial_return():
537+
return
538+
539+
540+
@filecheck_test
541+
@triton.jit
542+
def test_call_in_while():
543+
# CHECK-LABEL: test_call_in_while
544+
i = 0
545+
while i < 10:
546+
if i == 5:
547+
trivial_return()
548+
else:
549+
trivial_return()
550+
551+
552+
def test_return_in_while():
553+
554+
@triton.jit
555+
def kernel():
556+
i = 0
557+
while i < 10:
558+
if i == 5:
559+
return
560+
i += 1
561+
562+
with pytest.raises(CompilationError) as e:
563+
kernel.warmup(grid=(1, ))
564+
565+
assert "Cannot have `return` statements inside `while` or `for` statements in triton" in str(e.value)

python/triton/compiler/code_generator.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ def _visit_stmts(self, body) -> bool:
149149
return any(self.visit(s) for s in body)
150150

151151
def _visit_function(self, fn) -> bool:
152-
# no need to check within the function as it won't cause an early return.
153-
# If the function itself has unstructured control flow we may not be able to inline it causing poor performance.
154-
# We should check for this and fail or emit a warning.
152+
# No need to check within the function as it won't cause an early return.
153+
# If the function itself has unstructured control flow we may not be able to inline it causing poor performance,
154+
# we should check for this and emit a warning.
155155
return False
156156

157157
def generic_visit(self, node) -> bool:
@@ -857,13 +857,10 @@ def visit_If(self, node):
857857
% ast.unparse(node.test))
858858
cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
859859
cond = cond.to(language.int1, _semantic=self.semantic)
860-
contains_return = ContainsReturnChecker(self.gscope).visit(node)
861-
if contains_return:
860+
if ContainsReturnChecker(self.gscope).visit(node):
862861
if self.scf_stack:
863862
raise self._unsupported(
864-
node, "Cannot have `return` statements inside `while` or `for` statements in triton "
865-
"(note that this also applies to `return` statements that are inside functions "
866-
"transitively called from within `while`/`for` statements)")
863+
node, "Cannot have `return` statements inside `while` or `for` statements in triton.")
867864
self.visit_if_top_level(cond, node)
868865
else:
869866
self.visit_if_scf(cond, node)

0 commit comments

Comments
 (0)