|
2 | 2 | import torch |
3 | 3 | import triton.language as tl |
4 | 4 | import triton |
| 5 | +import sys |
| 6 | +import subprocess |
| 7 | +import os |
5 | 8 |
|
6 | 9 |
|
7 | 10 | @pytest.mark.parametrize('cond', [True, False]) |
|
10 | 13 | @pytest.mark.parametrize('env_var', [True, False]) |
11 | 14 | @pytest.mark.parametrize('jit_flag', [True, False]) |
12 | 15 | @pytest.mark.forked |
13 | | -def test_device_assert(monkeypatch, cond, mask, opt_flag, env_var, jit_flag, device): |
14 | | - monkeypatch.setenv("TRITON_DEBUG", str(int(env_var))) |
15 | | - triton.knobs.refresh_knobs() |
16 | | - torch.zeros([1], dtype=torch.int32, device=device) |
17 | | - |
18 | | - @triton.jit(debug=jit_flag) |
19 | | - def _kernel(COND: tl.constexpr, MASK: tl.constexpr): |
20 | | - tl.device_assert(COND, 'test', mask=MASK) |
def test_device_assert(cond, mask, opt_flag, env_var, jit_flag, device):
    """Launch the device-assert kernel in a child interpreter and check its exit status.

    Runs ``test_debug_kernels.py`` in a subprocess rather than in-process;
    temporary workaround for
    https://github.com/pytorch/pytorch/issues/142135
    """
    # Debug is enabled when the env var requests it; otherwise the explicit
    # launch-time opt flag wins over the jit-level flag when it is given.
    debug_enabled = env_var or (jit_flag if opt_flag is None else opt_flag)
    # A failure is only expected when the assert condition is false, debug
    # checking is active, and the assert is not masked off.
    expect_failure = not cond and debug_enabled and mask is not False

    helper = os.path.join(os.path.dirname(__file__), "test_debug_kernels.py")
    argv = [
        sys.executable,
        helper,
        "device_assert",
        str(cond),
        "None" if mask is None else str(mask),
        "None" if opt_flag is None else str(opt_flag),
        str(jit_flag),
        device,
        str(env_var),
    ]
    proc = subprocess.run(argv, capture_output=True, text=True)

    if expect_failure:
        # A Python RuntimeError exits with code 1; a device-side abort kills
        # the child with SIGABRT, which subprocess reports as -6.
        assert proc.returncode in (1, -6), (
            f"Expected runtime error or abort signal but got unexpected exit code {proc.returncode}. "
            f"stdout: {proc.stdout}, stderr: {proc.stderr}")
    else:
        assert proc.returncode == 0, (f"Expected success but got unexpected exit code {proc.returncode}. "
                                      f"stdout: {proc.stdout}, stderr: {proc.stderr}")
36 | 45 |
|
37 | 46 |
|
38 | 47 | def test_device_assert_barrier(monkeypatch, device): |
@@ -70,10 +79,14 @@ def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref |
70 | 79 | y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device) |
71 | 80 | z = torch.empty_like(x) |
72 | 81 | if should_overflow and debug: |
73 | | - with pytest.raises(RuntimeError) as exc_info: |
| 82 | + # with pytest.raises(RuntimeError) as exc_info: |
| 83 | + try: |
74 | 84 | tri_func[(1, )](x, y, z, debug=debug) |
75 | 85 | getattr(torch, device).synchronize() |
76 | | - assert "device-side assert" in str(exc_info.value) |
| 86 | + except RuntimeError as e: |
| 87 | + assert True |
| 88 | + assert "device-side assert" in str(e) #str(exc_info.value) |
| 89 | + assert False |
77 | 90 | else: |
78 | 91 | tri_func[(1, )](x, y, z, debug=debug) |
79 | 92 | getattr(torch, device).synchronize() |
|
0 commit comments