Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 89baedd

Browse files
authored
Trace the backward pass assuming contiguous tensors (#536)
1 parent d777fcc commit 89baedd

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

functorch/_src/aot_autograd.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def forward(ctx, *flat_tensor_args):
135135
with torch.set_grad_enabled(grad_state):
136136
out = flat_fn(*flat_tensor_args)
137137
out = pytree.tree_map(
138-
lambda x: x.detach() if isinstance(x, Tensor) else x, out
138+
lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
139139
)
140140

141141
if isinstance(out, (list, tuple)):
@@ -164,9 +164,8 @@ def forward(ctx, *flat_tensor_args):
164164

165165
@staticmethod
166166
def backward(ctx, *flat_args):
167-
# hmm... this doesn't feel right. todo
168-
# contiguous_args = [t.contiguous() for t in flat_args]
169-
contiguous_args = [t for t in flat_args]
167+
contiguous_args = [t.contiguous() for t in flat_args]
168+
# contiguous_args = [t for t in flat_args]
170169
out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
171170
return tuple(out)
172171

test/test_pythonkey.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,19 @@ def f(a, b, c, d):
527527
self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))
528528

529529

530+
class TestContiguous(TestCase):
531+
def test_contiguous(self):
532+
# The test simulates the condition where transpose followed by view
533+
# happens in the backward pass.
534+
# https://discuss.pytorch.org/t/error-on-transpose-and-view/434
535+
def f(x):
536+
return x.view(2, 3).t()
537+
538+
inp = torch.randn(6, requires_grad=True)
539+
out = aot_function(f, nop)(inp)
540+
torch.autograd.grad(out, inp, torch.randn(3, 2))
541+
542+
530543
only_for = ("cpu")
531544
instantiate_device_type_tests(
532545
TestPythonKey,

0 commit comments

Comments
 (0)