Skip to content

Commit a0d0335

Browse files
Added xfail for expected failures (#2805)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fb989d4 commit a0d0335

File tree

4 files changed

+46
-1
lines changed

4 files changed

+46
-1
lines changed

thunder/tests/framework.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,44 @@ def _bitsandbytes_available():
7777
BITSANDBYTES_AVAILABLE = _bitsandbytes_available()
7878

7979

80+
# TODO This change should be handled properly, this is a temporary fix to allow the CI to progress.
81+
# See https://github.com/Lightning-AI/lightning-thunder/issues/2807
82+
def _pytorch_removed_args_tensor_mask() -> bool:
83+
"""Check if PyTorch removed args_tensor_mask from autograd_function_apply.
84+
85+
PyTorch removed args_tensor_mask in https://github.com/pytorch/pytorch/pull/166788
86+
(commit 5cf15aef144fd03379a5796e37698eee5e4575b8, Dec 12, 2025).
87+
88+
Returns True if args_tensor_mask is NO LONGER accepted (new PyTorch).
89+
"""
90+
import re
91+
92+
version = torch.__version__
93+
# Nightly versions look like: 2.6.0.dev20251212+cpu
94+
match = re.search(r"\.dev(\d{8})", version)
95+
if match:
96+
nightly_date = int(match.group(1))
97+
# args_tensor_mask removed around Dec 12, 2025
98+
return nightly_date >= 20251212
99+
100+
# Lightning AI nightly builds have version strings like: 2.10.0a0+git62c80e7
101+
# We conservatively check against the base version (the portion before "+")
102+
# because we do not know whether the problematic commit removing args_tensor_mask
103+
# is before or after 2.10.0a0+git62c80e7. Therefore, we return True (masked as removed)
104+
# if we are on version 2.10.0a0 or later.
105+
base_version = packaging.version.parse(version.split("+")[0])
106+
if base_version >= packaging.version.parse("2.10.0a0"):
107+
return True
108+
109+
return False
110+
111+
112+
xfail_if_args_tensor_mask_removed = pytest.mark.xfail(
113+
_pytorch_removed_args_tensor_mask(),
114+
reason="PyTorch >= 2.10.0a0+git62c80e7 or nightly >= 20251212 removed args_tensor_mask from autograd_function_apply (PR #166788)",
115+
)
116+
117+
80118
def version_between(version: str, *, min_ver: str | None = None, max_ver: str | None = None):
81119
v = packaging.version.parse(version)
82120
if min_ver is not None and v < packaging.version.parse(min_ver):

thunder/tests/test_dynamo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
DynamoThunderExecutor,
3333
IS_WINDOWS,
3434
requiresCUDA,
35+
xfail_if_args_tensor_mask_removed,
3536
)
3637
from thunder.tests.make_tensor import make_tensor
3738
from thunder.dynamo.report import (
@@ -369,6 +370,7 @@ def func(x):
369370
strict=True,
370371
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
371372
),
373+
xfail_if_args_tensor_mask_removed,
372374
),
373375
)
374376
def test_splitter_autograd_function(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
@@ -1631,6 +1633,7 @@ def compile(self, fn, **kwargs):
16311633

16321634

16331635
@requiresCUDA
1636+
@xfail_if_args_tensor_mask_removed
16341637
def test_autograd_function_fx_report(tmp_path):
16351638
class Sin(torch.autograd.Function):
16361639
@staticmethod

thunder/tests/test_jit_general.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import thunder
1515

16-
from thunder.tests.framework import requiresCUDA, IS_WINDOWS
16+
from thunder.tests.framework import requiresCUDA, IS_WINDOWS, xfail_if_args_tensor_mask_removed
1717
from thunder.core.options import CACHE_OPTIONS
1818
import thunder.core.prims as prims
1919
from thunder import pytorch_executor, nvfuser_executor
@@ -1252,6 +1252,7 @@ def f(x):
12521252

12531253

12541254
@pytest.mark.filterwarnings("ignore:Please use torch.vmap")
1255+
@xfail_if_args_tensor_mask_removed
12551256
def test_autograd_function_apply():
12561257
# see https://github.com/Lightning-AI/lightning-thunder/issues/1248#issuecomment-2388655917
12571258
# for why `torch.foo` instead of `torch.Tensor.foo`
@@ -1328,6 +1329,7 @@ def my_sin_with_wrong_backward(x):
13281329
gradcheck(jitted, (x,))
13291330

13301331

1332+
@xfail_if_args_tensor_mask_removed
13311333
def test_autograd_function_apply_with_no_grad():
13321334
# This case is using `torch` operations
13331335
def forward(_, x):

thunder/tests/test_update_aliases.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TorchCompileExecutor,
2222
nvFuserExecutor,
2323
requiresCUDA,
24+
xfail_if_args_tensor_mask_removed,
2425
)
2526
from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
2627

@@ -477,6 +478,7 @@ def f(a):
477478

478479
@instantiate(
479480
dtypes=(dtypes.float32,),
481+
decorators=(xfail_if_args_tensor_mask_removed,),
480482
)
481483
def test_higher_order_inplace_alias_update(executor, device, dtype):
482484
torch_dtype = dtypes.to_torch_dtype(dtype)

0 commit comments

Comments
 (0)