
Fix KeyError for args_tensor_mask removed in PyTorch nightly #2804

Closed
IvanYashchuk wants to merge 1 commit into main from fix-2803


@IvanYashchuk
Collaborator

A recent PyTorch PR #166788 changed the autograd_function_apply higher-order op API by removing the args_tensor_mask keyword argument.

Currently, Thunder reads this key from the fwd_kwargs dict unconditionally, which raises a KeyError when the key is absent:

args_tensor_mask = unwrap(fwd_kwargs["args_tensor_mask"])

This PR adds a compatibility layer for both older PyTorch releases (which still provide args_tensor_mask) and newer PyTorch nightly versions (which removed it). When the key is not present, we infer the tensor mask by checking which arguments are TensorProxy instances.
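
A minimal sketch of the compatibility check described above, assuming `fwd_kwargs` is a plain dict. `TensorProxy` here is a stand-in class so the snippet runs without Thunder installed, the helper name `resolve_args_tensor_mask` is hypothetical, and the `unwrap` step from the original line is omitted for brevity. The actual patch may differ.

```python
class TensorProxy:
    """Stand-in for Thunder's TensorProxy, defined only so this sketch runs."""


def resolve_args_tensor_mask(fwd_args, fwd_kwargs):
    """Return one boolean per forward argument: True where it is a tensor.

    Hypothetical helper for illustration, not the code from this PR.
    """
    if "args_tensor_mask" in fwd_kwargs:
        # Older PyTorch releases still pass the mask explicitly.
        return list(fwd_kwargs["args_tensor_mask"])
    # Newer nightlies dropped the key, so infer the mask from the argument types.
    return [isinstance(arg, TensorProxy) for arg in fwd_args]
```

With this shape, callers do not need to know which PyTorch version produced `fwd_kwargs`; they always receive a mask of the same length as `fwd_args`.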

Let's see in CI whether further changes are needed anywhere else.

Fixes #2803.

Contributor

Copilot AI left a comment


Pull request overview

This PR adds a compatibility layer to handle the removal of the args_tensor_mask keyword argument from PyTorch's autograd_function_apply higher-order op API (removed in PyTorch PR #166788). The fix ensures Thunder works with both older PyTorch versions (that provide args_tensor_mask) and newer nightly versions (that removed it).

Key changes:

  • Added conditional logic to check if args_tensor_mask exists in fwd_kwargs before accessing it
  • When not present, infers the tensor mask by checking which arguments are TensorProxy instances


# NOTE: `args_tensor_mask` was removed in PyTorch PR #166788.
# https://github.com/pytorch/pytorch/pull/166788/
# For backwards compatibility with older PyTorch versions, we check if it exists.
# When not present, we assume all fwd_args are tensors.

Copilot AI Dec 15, 2025


The comment states "When not present, we assume all fwd_args are tensors" but the implementation actually checks each argument to determine if it's a TensorProxy. The comment should be updated to accurately reflect what the code does: "When not present, we infer the tensor mask by checking if each argument is a TensorProxy".

Suggested change
- # When not present, we assume all fwd_args are tensors.
+ # When not present, we infer the tensor mask by checking if each argument is a TensorProxy.
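
To illustrate why the wording matters, here is a small sketch contrasting the two strategies on a mixed argument list. `TensorProxy` is a stand-in class defined only for this example, and both function names are hypothetical rather than taken from Thunder.

```python
class TensorProxy:
    """Stand-in for Thunder's TensorProxy, defined only for illustration."""


def mask_assuming_all_tensors(fwd_args):
    # What the original comment describes: every argument is treated as a tensor.
    return [True] * len(fwd_args)


def mask_inferred_per_argument(fwd_args):
    # What the review says the code does: check each argument individually.
    return [isinstance(arg, TensorProxy) for arg in fwd_args]


fwd_args = (TensorProxy(), 2.0, TensorProxy())
print(mask_assuming_all_tensors(fwd_args))   # [True, True, True]
print(mask_inferred_per_argument(fwd_args))  # [True, False, True]
```

The two masks disagree as soon as a non-tensor argument (here the float `2.0`) appears, which is why the comment should describe the per-argument check.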

@IvanYashchuk
Collaborator Author

Closing in favor of #2802.


Development

Successfully merging this pull request may close these issues.

args_tensor_mask has been removed from torch.ops.higher_order.autograd_function_apply
