Skip to content

Commit e854fcd

Browse files
authored
[FRONTEND] Restore error traceback filtering (#5731)
The `filter_traceback` call was commented out during the tuple PR. This just restores it and adds a check in the relevant tests.
1 parent bc4675a commit e854fcd

File tree

2 files changed

+48
-27
lines changed

2 files changed

+48
-27
lines changed

python/test/unit/language/test_compile_errors.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300, is_hip_mi200
1111

1212

13+
def format_exception(type, value, tb):
14+
list_msg = traceback.format_exception(type, value, tb, chain=False)
15+
return "\n".join(list_msg)
16+
17+
1318
def test_err_undefined_variable():
1419

1520
@triton.jit
@@ -20,7 +25,9 @@ def kernel():
2025
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
2126

2227
try:
23-
assert "is not defined" in str(e.value), "error should mention the undefined variable"
28+
err_msg = format_exception(e.type, value=e.value, tb=e.tb)
29+
assert "is not defined" in err_msg, "error should mention the undefined variable"
30+
assert "code_generator.py" not in err_msg
2431
except AssertionError as assertion_err:
2532
raise assertion_err from e.value
2633

@@ -35,7 +42,9 @@ def kernel():
3542
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
3643

3744
try:
38-
assert "at 2:4:" in str(e.value), "error should point to the 0"
45+
err_msg = format_exception(e.type, value=e.value, tb=e.tb)
46+
assert "at 2:4:" in err_msg, "error should point to the 0"
47+
assert "code_generator.py" not in err_msg
3948
except AssertionError as assertion_err:
4049
raise assertion_err from e.value
4150

@@ -52,8 +61,11 @@ def kernel():
5261
try:
5362
assert isinstance(e.value, CompileTimeAssertionFailure)
5463
assert e.value.__cause__ is None
55-
assert "at 2:4:" in str(e.value), "error should point to the static_assert call"
56-
assert "<source unavailable>" not in str(e.value)
64+
err_msg = format_exception(e.type, value=e.value, tb=e.tb)
65+
print(err_msg)
66+
assert "at 2:4:" in err_msg, "error should point to the static_assert call"
67+
assert "<source unavailable>" not in err_msg
68+
assert "code_generator.py" not in err_msg
5769
except AssertionError as assertion_err:
5870
raise assertion_err from e.value
5971

@@ -70,8 +82,10 @@ def kernel():
7082

7183
try:
7284
assert e.value.__cause__ is None
73-
assert "at 2:4:" in str(e.value), "error should point to the `not`"
74-
assert "<source unavailable>" not in str(e.value)
85+
err_msg = format_exception(e.type, value=e.value, tb=e.tb)
86+
assert "at 2:4:" in err_msg, "error should point to the `not`"
87+
assert "<source unavailable>" not in err_msg
88+
assert "code_generator.py" not in err_msg
7589
except AssertionError as assertion_err:
7690
raise assertion_err from e.value
7791

@@ -86,8 +100,10 @@ def kernel():
86100
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
87101

88102
try:
89-
assert "at 2:4:" in str(e.value), "error should point to the 1.0"
90-
assert "<source unavailable>" not in str(e.value)
103+
err_msg = format_exception(e.type, value=e.value, tb=e.tb)
104+
assert "at 2:4:" in err_msg, "error should point to the 1.0"
105+
assert "<source unavailable>" not in err_msg
106+
assert "code_generator.py" not in err_msg
91107
except AssertionError as assertion_err:
92108
raise assertion_err from e.value
93109

@@ -110,13 +126,16 @@ def kernel():
110126
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
111127

112128
try:
113-
inner = e.value.__cause__
114-
outer = e.value
115-
assert "at 2:4:" in str(inner), "error should point to xyz"
116-
assert "<source unavailable>" not in str(inner)
117-
118-
assert "at 3:4" in str(outer), "error should point to the nested_call"
119-
assert "<source unavailable>" not in str(outer)
129+
inner_exc = e.value.__cause__
130+
inner = format_exception(inner_exc.__class__, inner_exc, inner_exc.__traceback__)
131+
assert "at 2:4:" in inner, "error should point to xyz"
132+
assert "<source unavailable>" not in inner
133+
assert "code_generator.py" not in inner
134+
135+
outer = format_exception(e.type, value=e.value, tb=e.tb)
136+
assert "at 3:4" in outer, "error should point to the nested_call"
137+
assert "<source unavailable>" not in outer
138+
assert "code_generator.py" not in outer
120139
except AssertionError as assertion_err:
121140
raise assertion_err from e.value
122141

@@ -133,13 +152,15 @@ def kernel():
133152
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
134153

135154
try:
136-
inner = e.value.__cause__
137-
outer = e.value
138-
assert f"{os.sep}core.py" in '\n'.join(traceback.format_tb(
139-
inner.__traceback__)), "error should point inside core.py"
140-
141-
assert "at 2:4:" in str(outer), "error should point to expand_dims call"
142-
assert "<source unavailable>" not in str(outer)
155+
inner_exc = e.value.__cause__
156+
inner = format_exception(inner_exc.__class__, inner_exc, inner_exc.__traceback__)
157+
assert f"{os.sep}core.py" in inner, "error should point inside core.py"
158+
assert "code_generator.py" not in inner
159+
160+
outer = format_exception(e.type, value=e.value, tb=e.tb)
161+
assert "at 2:4:" in outer, "error should point to expand_dims call"
162+
assert "<source unavailable>" not in outer
163+
assert "code_generator.py" not in outer
143164
except AssertionError as assertion_err:
144165
raise assertion_err from e.value
145166

python/triton/compiler/compiler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,11 @@ def compile(src, target=None, options=None):
271271

272272
codegen_fns = backend.get_codegen_implementation(options)
273273
module_map = backend.get_module_map()
274-
# try:
275-
module = src.make_ir(options, codegen_fns, module_map, context)
276-
# except Exception as e:
277-
# filter_traceback(e)
278-
# raise
274+
try:
275+
module = src.make_ir(options, codegen_fns, module_map, context)
276+
except Exception as e:
277+
filter_traceback(e)
278+
raise
279279
use_ir_loc = os.environ.get("USE_IR_LOC", None)
280280
for ext, compile_ir in list(stages.items())[first_stage:]:
281281
next_module = compile_ir(module, metadata)

0 commit comments

Comments
 (0)