File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed
Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -940,7 +940,15 @@ def _generate_random_str_id() -> str:
940940 length = 5
941941 return "" .join (secrets .choice (string .ascii_lowercase ) for _ in range (length ))
942942
943- args_tensor_mask = unwrap (fwd_kwargs ["args_tensor_mask" ])
943+ # NOTE: `args_tensor_mask` was removed in PyTorch PR #166788.
944+ # https://github.com/pytorch/pytorch/pull/166788/
945+ # For backwards compatibility with older PyTorch versions, we check if it exists.
946+ # When not present, we assume all fwd_args are tensors.
947+ if "args_tensor_mask" in fwd_kwargs :
948+ args_tensor_mask = unwrap (fwd_kwargs ["args_tensor_mask" ])
949+ else :
950+ # For new PyTorch versions without args_tensor_mask, treat all args as potential tensors
951+ args_tensor_mask = tuple (isinstance (unwrap (arg ), TensorProxy ) for arg in fwd_args )
944952 # TODO(crcrpar): Think about making use of `non_differentiable_idx`
945953 # note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
946954 # non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
You can’t perform that action at this time.
0 commit comments