Fix KeyError for args_tensor_mask removed in PyTorch nightly #2804
IvanYashchuk wants to merge 1 commit into main
Pull request overview
This PR adds a compatibility layer to handle the removal of the args_tensor_mask keyword argument from PyTorch's autograd_function_apply higher-order op API (removed in PyTorch PR #166788). The fix ensures Thunder works with both older PyTorch versions (that provide args_tensor_mask) and newer nightly versions (that removed it).
Key changes:
- Added conditional logic to check if args_tensor_mask exists in fwd_kwargs before accessing it
- When not present, infers the tensor mask by checking which arguments are TensorProxy instances
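The two changes above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `TensorProxyStandIn` is a hypothetical stand-in for Thunder's TensorProxy so the snippet runs without Thunder installed, and `resolve_args_tensor_mask` is an illustrative helper name.

```python
class TensorProxyStandIn:
    """Hypothetical stand-in for Thunder's TensorProxy (assumption, not the real class)."""


def resolve_args_tensor_mask(fwd_args, fwd_kwargs, tensor_type=TensorProxyStandIn):
    """Return one boolean per forward argument: True where the argument is a tensor."""
    if "args_tensor_mask" in fwd_kwargs:
        # Older PyTorch: the mask is still passed explicitly, so use it as given.
        return list(fwd_kwargs["args_tensor_mask"])
    # Newer PyTorch nightlies: the key is gone, so infer the mask from argument types.
    return [isinstance(arg, tensor_type) for arg in fwd_args]


x, y = TensorProxyStandIn(), TensorProxyStandIn()

# Older-style kwargs carry the mask; newer-style kwargs do not.
old_style = resolve_args_tensor_mask((x, 2.0, y), {"args_tensor_mask": [True, False, True]})
new_style = resolve_args_tensor_mask((x, 2.0, y), {})
print(old_style, new_style)
```

Both paths should yield the same mask whenever the explicit mask agrees with the argument types, which is the property the compatibility layer relies on.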
# NOTE: `args_tensor_mask` was removed in PyTorch PR #166788.
# https://github.com/pytorch/pytorch/pull/166788/
# For backwards compatibility with older PyTorch versions, we check if it exists.
# When not present, we assume all fwd_args are tensors.
The comment states "When not present, we assume all fwd_args are tensors" but the implementation actually checks each argument to determine if it's a TensorProxy. The comment should be updated to accurately reflect what the code does: "When not present, we infer the tensor mask by checking if each argument is a TensorProxy".
Suggested change:
- # When not present, we assume all fwd_args are tensors.
+ # When not present, we infer the tensor mask by checking if each argument is a TensorProxy.
Closing in favor of #2802.
A recent PyTorch PR #166788 changed the autograd_function_apply higher-order op API by removing the args_tensor_mask keyword argument.
Currently, Thunder reads this key from the fwd_kwargs dict unconditionally, which raises a KeyError on newer PyTorch nightlies:
lightning-thunder/thunder/core/jit_ext.py
Line 943 in fb989d4
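The failure mode can be reproduced with a plain dict. The code at the referenced line is not reproduced here; this sketch only assumes it is a direct dictionary lookup, and `read_mask_unguarded` and the sample kwargs are illustrative names, not Thunder code.

```python
def read_mask_unguarded(fwd_kwargs):
    # Direct indexing: raises KeyError when the key is absent.
    return fwd_kwargs["args_tensor_mask"]


old_kwargs = {"args_tensor_mask": [True, False]}  # illustrative kwargs from older PyTorch
new_kwargs = {}  # newer nightlies no longer pass the key

mask = read_mask_unguarded(old_kwargs)

try:
    read_mask_unguarded(new_kwargs)
    raised = False
except KeyError as err:
    raised = True
    missing_key = err.args[0]  # the key that was looked up

print(mask, raised)
```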
This PR adds a compatibility layer for both older PyTorch releases (which still provide args_tensor_mask) and newer PyTorch nightly versions (which removed it). When the key is not present, we infer the tensor mask by checking which arguments are TensorProxy instances.
Let's see in CI whether further changes are needed anywhere else.
Fixes #2803.