@@ -24,21 +24,21 @@ def test_device_assert(cond, mask, opt_flag, env_var, jit_flag, device):
2424 mask_str = "None" if mask is None else str (mask )
2525 opt_flag_str = "None" if opt_flag is None else str (opt_flag )
2626
27- result = subprocess .run ([
28- sys .executable , kernel_file , "device_assert" ,
29- str (cond ), mask_str , opt_flag_str ,
30- str (jit_flag ), device ,
31- str (env_var )
32- ], capture_output = True , text = True )
27+ env = os .environ .copy ()
28+ env ["TRITON_DEBUG" ] = str (int (env_var ))
29+
30+ result = subprocess .run (
31+ [sys .executable , kernel_file , "device_assert" ,
32+ str (cond ), mask_str , opt_flag_str ,
33+ str (jit_flag ), device ], capture_output = True , text = True , env = env )
3334
3435 if should_fail :
35- abort_or_runtime_error = (
36- result .returncode == 1 or # RuntimeError
37- result .returncode == - 6 # SIGABRT
38- )
39- assert abort_or_runtime_error , (
40- f"Expected runtime error or abort signal but got unexpected exit code { result .returncode } . "
41- f"stdout: { result .stdout } , stderr: { result .stderr } " )
36+ if device == 'xpu' :
37+ assert result .returncode == - 6 , (f"Expected SIGABRT but got exit code { result .returncode } . "
38+ f"stdout: { result .stdout } , stderr: { result .stderr } " )
39+ else :
40+ assert result .returncode == 1 , (f"Expected runtime error but got unexpected exit code { result .returncode } . "
41+ f"stdout: { result .stdout } , stderr: { result .stderr } " )
4242 else :
4343 assert result .returncode == 0 , (f"Expected success but got unexpected exit code { result .returncode } . "
4444 f"stdout: { result .stdout } , stderr: { result .stderr } " )
0 commit comments