1010from triton ._internal_testing import is_interpreter , is_cuda , is_hip , is_hip_mi300 , is_hip_mi200 , is_xpu
1111
1212
13+ def format_exception (type , value , tb ):
14+ list_msg = traceback .format_exception (type , value , tb , chain = False )
15+ return "\n " .join (list_msg )
16+
17+
1318def test_err_undefined_variable ():
1419
1520 @triton .jit
@@ -20,7 +25,9 @@ def kernel():
2025 triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
2126
2227 try :
23- assert "is not defined" in str (e .value ), "error should mention the undefined variable"
28+ err_msg = format_exception (e .type , value = e .value , tb = e .tb )
29+ assert "is not defined" in err_msg , "error should mention the undefined variable"
30+ assert "code_generator.py" not in err_msg
2431 except AssertionError as assertion_err :
2532 raise assertion_err from e .value
2633
@@ -35,7 +42,9 @@ def kernel():
3542 triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
3643
3744 try :
38- assert "at 2:4:" in str (e .value ), "error should point to the 0"
45+ err_msg = format_exception (e .type , value = e .value , tb = e .tb )
46+ assert "at 2:4:" in err_msg , "error should point to the 0"
47+ assert "code_generator.py" not in err_msg
3948 except AssertionError as assertion_err :
4049 raise assertion_err from e .value
4150
@@ -52,8 +61,11 @@ def kernel():
5261 try :
5362 assert isinstance (e .value , CompileTimeAssertionFailure )
5463 assert e .value .__cause__ is None
55- assert "at 2:4:" in str (e .value ), "error should point to the static_assert call"
56- assert "<source unavailable>" not in str (e .value )
64+ err_msg = format_exception (e .type , value = e .value , tb = e .tb )
65+ print (err_msg )
66+ assert "at 2:4:" in err_msg , "error should point to the static_assert call"
67+ assert "<source unavailable>" not in err_msg
68+ assert "code_generator.py" not in err_msg
5769 except AssertionError as assertion_err :
5870 raise assertion_err from e .value
5971
@@ -70,8 +82,10 @@ def kernel():
7082
7183 try :
7284 assert e .value .__cause__ is None
73- assert "at 2:4:" in str (e .value ), "error should point to the `not`"
74- assert "<source unavailable>" not in str (e .value )
85+ err_msg = format_exception (e .type , value = e .value , tb = e .tb )
86+ assert "at 2:4:" in err_msg , "error should point to the `not`"
87+ assert "<source unavailable>" not in err_msg
88+ assert "code_generator.py" not in err_msg
7589 except AssertionError as assertion_err :
7690 raise assertion_err from e .value
7791
@@ -86,8 +100,10 @@ def kernel():
86100 triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
87101
88102 try :
89- assert "at 2:4:" in str (e .value ), "error should point to the 1.0"
90- assert "<source unavailable>" not in str (e .value )
103+ err_msg = format_exception (e .type , value = e .value , tb = e .tb )
104+ assert "at 2:4:" in err_msg , "error should point to the 1.0"
105+ assert "<source unavailable>" not in err_msg
106+ assert "code_generator.py" not in err_msg
91107 except AssertionError as assertion_err :
92108 raise assertion_err from e .value
93109
@@ -110,13 +126,16 @@ def kernel():
110126 triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
111127
112128 try :
113- inner = e .value .__cause__
114- outer = e .value
115- assert "at 2:4:" in str (inner ), "error should point to xyz"
116- assert "<source unavailable>" not in str (inner )
117-
118- assert "at 3:4" in str (outer ), "error should point to the nested_call"
119- assert "<source unavailable>" not in str (outer )
129+ inner_exc = e .value .__cause__
130+ inner = format_exception (inner_exc .__class__ , inner_exc , inner_exc .__traceback__ )
131+ assert "at 2:4:" in inner , "error should point to xyz"
132+ assert "<source unavailable>" not in inner
133+ assert "code_generator.py" not in inner
134+
135+ outer = format_exception (e .type , value = e .value , tb = e .tb )
136+ assert "at 3:4" in outer , "error should point to the nested_call"
137+ assert "<source unavailable>" not in outer
138+ assert "code_generator.py" not in outer
120139 except AssertionError as assertion_err :
121140 raise assertion_err from e .value
122141
@@ -133,13 +152,15 @@ def kernel():
133152 triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
134153
135154 try :
136- inner = e .value .__cause__
137- outer = e .value
138- assert f"{ os .sep } core.py" in '\n ' .join (traceback .format_tb (
139- inner .__traceback__ )), "error should point inside core.py"
140-
141- assert "at 2:4:" in str (outer ), "error should point to expand_dims call"
142- assert "<source unavailable>" not in str (outer )
155+ inner_exc = e .value .__cause__
156+ inner = format_exception (inner_exc .__class__ , inner_exc , inner_exc .__traceback__ )
157+ assert f"{ os .sep } core.py" in inner , "error should point inside core.py"
158+ assert "code_generator.py" not in inner
159+
160+ outer = format_exception (e .type , value = e .value , tb = e .tb )
161+ assert "at 2:4:" in outer , "error should point to expand_dims call"
162+ assert "<source unavailable>" not in outer
163+ assert "code_generator.py" not in outer
143164 except AssertionError as assertion_err :
144165 raise assertion_err from e .value
145166
0 commit comments