Skip to content

Commit 71a3240

Browse files
committed
split assert conditions
1 parent 5fe75f1 commit 71a3240

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

python/test/unit/test_debug.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,21 @@ def test_device_assert(cond, mask, opt_flag, env_var, jit_flag, device):
2424
mask_str = "None" if mask is None else str(mask)
2525
opt_flag_str = "None" if opt_flag is None else str(opt_flag)
2626

27-
result = subprocess.run([
28-
sys.executable, kernel_file, "device_assert",
29-
str(cond), mask_str, opt_flag_str,
30-
str(jit_flag), device,
31-
str(env_var)
32-
], capture_output=True, text=True)
27+
env = os.environ.copy()
28+
env["TRITON_DEBUG"] = str(int(env_var))
29+
30+
result = subprocess.run(
31+
[sys.executable, kernel_file, "device_assert",
32+
str(cond), mask_str, opt_flag_str,
33+
str(jit_flag), device], capture_output=True, text=True, env=env)
3334

3435
if should_fail:
35-
abort_or_runtime_error = (
36-
result.returncode == 1 or # RuntimeError
37-
result.returncode == -6 # SIGABRT
38-
)
39-
assert abort_or_runtime_error, (
40-
f"Expected runtime error or abort signal but got unexpected exit code {result.returncode}. "
41-
f"stdout: {result.stdout}, stderr: {result.stderr}")
36+
if device == 'xpu':
37+
assert result.returncode == -6, (f"Expected SIGABRT but got exit code {result.returncode}. "
38+
f"stdout: {result.stdout}, stderr: {result.stderr}")
39+
else:
40+
assert result.returncode == 1, (f"Expected runtime error but got unexpected exit code {result.returncode}. "
41+
f"stdout: {result.stdout}, stderr: {result.stderr}")
4242
else:
4343
assert result.returncode == 0, (f"Expected success but got unexpected exit code {result.returncode}. "
4444
f"stdout: {result.stdout}, stderr: {result.stderr}")

python/test/unit/test_debug_kernels.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import triton
77
import triton.language as tl
88
import sys
9-
import os
109

1110

1211
def run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device):
@@ -44,9 +43,6 @@ def parse_bool_or_none(arg_str):
4443
opt_flag = parse_bool_or_none(sys.argv[4])
4544
jit_flag = sys.argv[5] == "True"
4645
device = sys.argv[6]
47-
env_var = sys.argv[7] == "True"
48-
49-
os.environ["TRITON_DEBUG"] = str(int(env_var))
5046
triton.knobs.refresh_knobs()
5147
exit_code = run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device)
5248
sys.exit(exit_code)

0 commit comments

Comments
 (0)