Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit d777fcc

Browse files
authored
fix functionalize(): properly propagate input mutations (#654)
1 parent deb706d commit d777fcc

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

functorch/_src/eager_transforms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,10 +1224,10 @@ def wrapped(*args, **kwargs):
12241224
func_args = _wrap_all_tensors_to_functional(args, func_level)
12251225
func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)
12261226

1227-
flattened_unwrapped_args = tree_flatten(args)
1228-
flattened_wrapped_args = tree_flatten(func_args)
1229-
flattened_unwrapped_kwargs = tree_flatten(kwargs)
1230-
flattened_wrapped_kwargs = tree_flatten(func_kwargs)
1227+
flattened_unwrapped_args, _ = tree_flatten(args)
1228+
flattened_wrapped_args, _ = tree_flatten(func_args)
1229+
flattened_unwrapped_kwargs, _ = tree_flatten(kwargs)
1230+
flattened_wrapped_kwargs, _ = tree_flatten(func_kwargs)
12311231

12321232
func_outputs = func(*func_args, **func_kwargs)
12331233
outputs = _unwrap_all_tensors_from_functional(func_outputs)

test/test_eager_transforms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2776,6 +2776,9 @@ def _check_functionalize_correctness(self, f, inpt):
27762776
# Check that outputs are the same
27772777
self.assertEqual(actual_outputs, expected_outputs)
27782778

2779+
# Inputs might have been mutated by f: check that they were mutated properly
2780+
self.assertEqual(inpt1, inpt2)
2781+
27792782
def test_simple_view(self, device):
27802783
def f(x: torch.Tensor) -> torch.Tensor:
27812784
tmp = torch.ones(2, device=device)

0 commit comments

Comments
 (0)