From 3de296d2d473e18868125fdeea10c17f719a8b6b Mon Sep 17 00:00:00 2001 From: dev-tomek Date: Tue, 28 Oct 2025 11:41:23 +0000 Subject: [PATCH 1/3] [TEST_DEBUG] Enable test_device_assert --- python/test/unit/test_debug.py | 57 ++++++++++++++++---------- python/test/unit/test_debug_kernels.py | 56 +++++++++++++++++++++++++ scripts/skiplist/a770/debug.txt | 1 - scripts/skiplist/arl-h/debug.txt | 1 - scripts/skiplist/arl-s/debug.txt | 1 - scripts/skiplist/default/debug.txt | 1 - scripts/skiplist/lts/debug.txt | 1 - scripts/skiplist/mtl/debug.txt | 1 - scripts/skiplist/xe2/debug.txt | 1 - 9 files changed, 91 insertions(+), 29 deletions(-) create mode 100644 python/test/unit/test_debug_kernels.py diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py index e771da39a1..8d6d67569f 100644 --- a/python/test/unit/test_debug.py +++ b/python/test/unit/test_debug.py @@ -2,6 +2,9 @@ import torch import triton.language as tl import triton +import sys +import subprocess +import os @pytest.mark.parametrize('cond', [True, False]) @@ -10,29 +13,35 @@ @pytest.mark.parametrize('env_var', [True, False]) @pytest.mark.parametrize('jit_flag', [True, False]) @pytest.mark.forked -def test_device_assert(monkeypatch, cond, mask, opt_flag, env_var, jit_flag, device): - monkeypatch.setenv("TRITON_DEBUG", str(int(env_var))) - triton.knobs.refresh_knobs() - torch.zeros([1], dtype=torch.int32, device=device) - - @triton.jit(debug=jit_flag) - def _kernel(COND: tl.constexpr, MASK: tl.constexpr): - tl.device_assert(COND, 'test', mask=MASK) +def test_device_assert(cond, mask, opt_flag, env_var, jit_flag, device): + """Temporary subprocess solution due to: + https://github.com/pytorch/pytorch/issues/142135""" is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag) - kwargs = {} - if opt_flag is not None: - kwargs["debug"] = opt_flag - - if not cond and is_debug and mask is not False: - with pytest.raises(RuntimeError): - _kernel[(1, )](cond, mask, **kwargs) - getattr(torch, device).synchronize() - return - - _kernel[(1, )](cond, mask, **kwargs) - getattr(torch, device).synchronize() + should_fail = not cond and is_debug and mask is not False + kernel_file = os.path.join(os.path.dirname(__file__), "test_debug_kernels.py") + mask_str = "None" if mask is None else str(mask) + opt_flag_str = "None" if opt_flag is None else str(opt_flag) + + result = subprocess.run([ + sys.executable, kernel_file, "device_assert", + str(cond), mask_str, opt_flag_str, + str(jit_flag), device, + str(env_var) + ], capture_output=True, text=True) + + if should_fail: + abort_or_runtime_error = ( + result.returncode == 1 or # RuntimeError + result.returncode == -6 # SIGABRT + ) + assert abort_or_runtime_error, ( + f"Expected runtime error or abort signal but got unexpected exit code {result.returncode}. " + f"stdout: {result.stdout}, stderr: {result.stderr}") + else: + assert result.returncode == 0, (f"Expected success but got unexpected exit code {result.returncode}. " + f"stdout: {result.stdout}, stderr: {result.stderr}") def test_device_assert_barrier(monkeypatch, device): @@ -70,10 +79,14 @@ def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device) z = torch.empty_like(x) if should_overflow and debug: - with pytest.raises(RuntimeError) as exc_info: + # with pytest.raises(RuntimeError) as exc_info: + try: tri_func[(1, )](x, y, z, debug=debug) getattr(torch, device).synchronize() - assert "device-side assert" in str(exc_info.value) + except RuntimeError as e: + assert True + assert "device-side assert" in str(e) #str(exc_info.value) + assert False else: tri_func[(1, )](x, y, z, debug=debug) getattr(torch, device).synchronize() diff --git a/python/test/unit/test_debug_kernels.py b/python/test/unit/test_debug_kernels.py new file mode 100644 index 0000000000..df8e146382 --- /dev/null +++ b/python/test/unit/test_debug_kernels.py @@ -0,0 +1,56 @@ +""" +Helper module containing Triton kernels for test_debug.py. +These kernels are separated so they can be called from subprocesses. +""" +import torch +import triton +import triton.language as tl +import sys +import os + + +def run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device): + + @triton.jit(debug=jit_flag) + def _kernel(COND: tl.constexpr, MASK: tl.constexpr): + tl.device_assert(COND, 'test', mask=MASK) + + kwargs = {} + if opt_flag is not None: + kwargs["debug"] = opt_flag + + try: + _kernel[(1, )](cond, mask, **kwargs) + getattr(torch, device).synchronize() + return 0 + except RuntimeError: + return 1 + except Exception as e: + print(f"Unexpected error: {type(e).__name__}: {e}") + return 2 + + +if __name__ == "__main__": + + def parse_bool_or_none(arg_str): + if arg_str == "None": + return None + return arg_str == "True" + + test_type = sys.argv[1] + if test_type == "device_assert": + cond = sys.argv[2] == "True" + mask = parse_bool_or_none(sys.argv[3]) + opt_flag = parse_bool_or_none(sys.argv[4]) + jit_flag = sys.argv[5] == "True" + device = sys.argv[6] + env_var = sys.argv[7] == "True" + + os.environ["TRITON_DEBUG"] = str(int(env_var)) + triton.knobs.refresh_knobs() + exit_code = run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device) + sys.exit(exit_code) + + else: + print(f"Unknown test type: {test_type}") + sys.exit(3) diff --git a/scripts/skiplist/a770/debug.txt b/scripts/skiplist/a770/debug.txt index b40a78489a..6915e5520f 100644 --- a/scripts/skiplist/a770/debug.txt +++ b/scripts/skiplist/a770/debug.txt @@ -1,5 +1,4 @@ # https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 -python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] diff --git a/scripts/skiplist/arl-h/debug.txt b/scripts/skiplist/arl-h/debug.txt index b40a78489a..6915e5520f 100644 --- a/scripts/skiplist/arl-h/debug.txt +++ b/scripts/skiplist/arl-h/debug.txt @@ -1,5 +1,4 @@ # https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 -python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] diff --git a/scripts/skiplist/arl-s/debug.txt b/scripts/skiplist/arl-s/debug.txt index b40a78489a..6915e5520f 100644 --- a/scripts/skiplist/arl-s/debug.txt +++ b/scripts/skiplist/arl-s/debug.txt @@ -1,5 +1,4 @@ # https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 -python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] diff --git a/scripts/skiplist/default/debug.txt b/scripts/skiplist/default/debug.txt index b40a78489a..6915e5520f 100644 --- a/scripts/skiplist/default/debug.txt +++ b/scripts/skiplist/default/debug.txt @@ -1,5 +1,4 @@ # https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 -python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] diff --git a/scripts/skiplist/lts/debug.txt b/scripts/skiplist/lts/debug.txt index b40a78489a..6915e5520f 100644 --- a/scripts/skiplist/lts/debug.txt +++ b/scripts/skiplist/lts/debug.txt @@ -1,5 +1,4 @@ # https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 -python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] diff --git a/scripts/skiplist/mtl/debug.txt b/scripts/skiplist/mtl/debug.txt index b40a78489a..6915e5520f 100644 --- a/scripts/skiplist/mtl/debug.txt +++ b/scripts/skiplist/mtl/debug.txt @@ -1,5 +1,4 @@ # https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 -python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] diff --git a/scripts/skiplist/xe2/debug.txt b/scripts/skiplist/xe2/debug.txt index b40a78489a..6915e5520f 100644 --- a/scripts/skiplist/xe2/debug.txt +++ b/scripts/skiplist/xe2/debug.txt @@ -1,5 +1,4 @@ # https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 -python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] From 4df011deddaaa05e75d1ed0c27fb62184333234c Mon Sep 17 00:00:00 2001 From: dev-tomek Date: Tue, 28 Oct 2025 11:45:54 +0000 Subject: [PATCH 2/3] delete leftovers --- python/test/unit/test_debug.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py index 8d6d67569f..5689a0b32e 100644 --- a/python/test/unit/test_debug.py +++ b/python/test/unit/test_debug.py @@ -79,14 +79,10 @@ def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device) z = torch.empty_like(x) if should_overflow and debug: - # with pytest.raises(RuntimeError) as exc_info: - try: + with pytest.raises(RuntimeError) as exc_info: tri_func[(1, )](x, y, z, debug=debug) getattr(torch, device).synchronize() - except RuntimeError as e: - assert True - assert "device-side assert" in str(e) #str(exc_info.value) - assert False + assert "device-side assert" in str(exc_info.value) else: tri_func[(1, )](x, y, z, debug=debug) getattr(torch, device).synchronize() From 71a324091f3e33acf60cd4c2412f66f67dcd3243 Mon Sep 17 00:00:00 2001 From: dev-tomek Date: Tue, 4 Nov 2025 11:06:49 +0000 Subject: [PATCH 3/3] split assert conditions --- python/test/unit/test_debug.py | 26 +++++++++++++------------- python/test/unit/test_debug_kernels.py | 4 ---- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py index 5689a0b32e..81fa5e8a3d 100644 --- a/python/test/unit/test_debug.py +++ b/python/test/unit/test_debug.py @@ -24,21 +24,21 @@ def test_device_assert(cond, mask, opt_flag, env_var, jit_flag, device): mask_str = "None" if mask is None else str(mask) opt_flag_str = "None" if opt_flag is None else str(opt_flag) - result = subprocess.run([ - sys.executable, kernel_file, "device_assert", - str(cond), mask_str, opt_flag_str, - str(jit_flag), device, - str(env_var) - ], capture_output=True, text=True) + env = os.environ.copy() + env["TRITON_DEBUG"] = str(int(env_var)) + + result = subprocess.run( + [sys.executable, kernel_file, "device_assert", + str(cond), mask_str, opt_flag_str, + str(jit_flag), device], capture_output=True, text=True, env=env) if should_fail: - abort_or_runtime_error = ( - result.returncode == 1 or # RuntimeError - result.returncode == -6 # SIGABRT - ) - assert abort_or_runtime_error, ( - f"Expected runtime error or abort signal but got unexpected exit code {result.returncode}. " - f"stdout: {result.stdout}, stderr: {result.stderr}") + if device == 'xpu': + assert result.returncode == -6, (f"Expected SIGABRT but got exit code {result.returncode}. " + f"stdout: {result.stdout}, stderr: {result.stderr}") + else: + assert result.returncode == 1, (f"Expected runtime error but got unexpected exit code {result.returncode}. " + f"stdout: {result.stdout}, stderr: {result.stderr}") else: assert result.returncode == 0, (f"Expected success but got unexpected exit code {result.returncode}. " f"stdout: {result.stdout}, stderr: {result.stderr}") diff --git a/python/test/unit/test_debug_kernels.py b/python/test/unit/test_debug_kernels.py index df8e146382..e4d57a1630 100644 --- a/python/test/unit/test_debug_kernels.py +++ b/python/test/unit/test_debug_kernels.py @@ -6,7 +6,6 @@ import triton import triton.language as tl import sys -import os def run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device): @@ -44,9 +43,6 @@ def parse_bool_or_none(arg_str): opt_flag = parse_bool_or_none(sys.argv[4]) jit_flag = sys.argv[5] == "True" device = sys.argv[6] - env_var = sys.argv[7] == "True" - - os.environ["TRITON_DEBUG"] = str(int(env_var)) triton.knobs.refresh_knobs() exit_code = run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device) sys.exit(exit_code)