Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions thunder/tests/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,43 @@ def _bitsandbytes_available():
BITSANDBYTES_AVAILABLE = _bitsandbytes_available()


# TODO This change should be handled properly in the future
def _pytorch_removed_args_tensor_mask() -> bool:
    """Detect whether this PyTorch build no longer accepts ``args_tensor_mask``.

    The ``args_tensor_mask`` argument of ``autograd_function_apply`` was
    dropped in https://github.com/pytorch/pytorch/pull/166788
    (commit 5cf15aef144fd03379a5796e37698eee5e4575b8, Dec 12, 2025).

    Returns True when the installed PyTorch is recent enough that the
    argument has been removed.
    """
    import re

    version_str = torch.__version__

    # Nightly wheels embed a build date, e.g. "2.6.0.dev20251212+cpu";
    # compare that date directly against the day the argument was removed.
    nightly = re.search(r"\.dev(\d{8})", version_str)
    if nightly:
        # args_tensor_mask removed around Dec 12, 2025
        return int(nightly.group(1)) >= 20251212

    # Source/CI builds look like "2.10.0a0+git62c80e7".  The local segment
    # after "+" does not reveal whether the removing commit is included, so
    # conservatively treat everything at or above 2.10.0a0 as having the
    # argument removed, comparing only the base version (before "+").
    stripped = version_str.split("+")[0]
    return packaging.version.parse(stripped) >= packaging.version.parse("2.10.0a0")


# Marker for tests that exercise the (now removed) ``args_tensor_mask``
# argument of ``autograd_function_apply``: such tests are expected to fail
# on PyTorch builds where the argument no longer exists (see the version
# probe above for how removal is detected).
xfail_if_args_tensor_mask_removed = pytest.mark.xfail(
    _pytorch_removed_args_tensor_mask(),
    reason="PyTorch >= 2.10.0a0+git62c80e7 or nightly >= 20251212 removed args_tensor_mask from autograd_function_apply (PR #166788)",
)


def version_between(version: str, *, min_ver: str | None = None, max_ver: str | None = None):
v = packaging.version.parse(version)
if min_ver is not None and v < packaging.version.parse(min_ver):
Expand Down
3 changes: 3 additions & 0 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
DynamoThunderExecutor,
IS_WINDOWS,
requiresCUDA,
xfail_if_args_tensor_mask_removed,
)
from thunder.tests.make_tensor import make_tensor
from thunder.dynamo.report import (
Expand Down Expand Up @@ -369,6 +370,7 @@ def func(x):
strict=True,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
xfail_if_args_tensor_mask_removed,
),
)
def test_splitter_autograd_function(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
Expand Down Expand Up @@ -1631,6 +1633,7 @@ def compile(self, fn, **kwargs):


@requiresCUDA
@xfail_if_args_tensor_mask_removed
def test_autograd_function_fx_report(tmp_path):
class Sin(torch.autograd.Function):
@staticmethod
Expand Down
4 changes: 3 additions & 1 deletion thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import thunder

from thunder.tests.framework import requiresCUDA, IS_WINDOWS
from thunder.tests.framework import requiresCUDA, IS_WINDOWS, xfail_if_args_tensor_mask_removed
from thunder.core.options import CACHE_OPTIONS
import thunder.core.prims as prims
from thunder import pytorch_executor, nvfuser_executor
Expand Down Expand Up @@ -1252,6 +1252,7 @@ def f(x):


@pytest.mark.filterwarnings("ignore:Please use torch.vmap")
@xfail_if_args_tensor_mask_removed
def test_autograd_function_apply():
# see https://github.com/Lightning-AI/lightning-thunder/issues/1248#issuecomment-2388655917
# for why `torch.foo` instead of `torch.Tensor.foo`
Expand Down Expand Up @@ -1328,6 +1329,7 @@ def my_sin_with_wrong_backward(x):
gradcheck(jitted, (x,))


@xfail_if_args_tensor_mask_removed
def test_autograd_function_apply_with_no_grad():
# This case is using `torch` operations
def forward(_, x):
Expand Down
2 changes: 2 additions & 0 deletions thunder/tests/test_update_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TorchCompileExecutor,
nvFuserExecutor,
requiresCUDA,
xfail_if_args_tensor_mask_removed,
)
from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place

Expand Down Expand Up @@ -477,6 +478,7 @@ def f(a):

@instantiate(
dtypes=(dtypes.float32,),
decorators=(xfail_if_args_tensor_mask_removed,),
)
def test_higher_order_inplace_alias_update(executor, device, dtype):
torch_dtype = dtypes.to_torch_dtype(dtype)
Expand Down
Loading