Skip to content

Commit 700ba8a

Browse files
committed
Fix KeyError for args_tensor_mask removed in PyTorch PR #166788
1 parent fb989d4 commit 700ba8a

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

thunder/core/jit_ext.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,15 @@ def _generate_random_str_id() -> str:
940940
length = 5
941941
return "".join(secrets.choice(string.ascii_lowercase) for _ in range(length))
942942

943-
args_tensor_mask = unwrap(fwd_kwargs["args_tensor_mask"])
943+
# NOTE: `args_tensor_mask` was removed in PyTorch PR #166788.
944+
# https://github.com/pytorch/pytorch/pull/166788/
945+
# For backwards compatibility with older PyTorch versions, we check if it exists.
946+
# When not present, we assume all fwd_args are tensors.
947+
if "args_tensor_mask" in fwd_kwargs:
948+
args_tensor_mask = unwrap(fwd_kwargs["args_tensor_mask"])
949+
else:
950+
# For new PyTorch versions without args_tensor_mask, treat all args as potential tensors
951+
args_tensor_mask = tuple(isinstance(unwrap(arg), TensorProxy) for arg in fwd_args)
944952
# TODO(crcrpar): Think about making use of `non_differentiable_idx`
945953
# note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
946954
# non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")

0 commit comments

Comments
 (0)