Skip to content

Commit 1aaedbc

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][hops] Add xfail tests for side effects (pytorch#168394)
Pull Request resolved: pytorch#168394 Approved by: https://github.com/jansel
1 parent f1c49c9 commit 1aaedbc

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

test/dynamo/test_autograd_function.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# flake8: noqa: B950
33
import copy
44
import math
5+
import unittest
56
from dataclasses import dataclass
67

78
import torch
@@ -1543,6 +1544,43 @@ def f(x, y):
15431544
loss.backward()
15441545
self.assertEqual(x + y, z)
15451546

1547+
@unittest.expectedFailure
1548+
def test_nonlocal_list_mutation_in_autograd_function(self):
1549+
"""Test that nonlocal list mutation in autograd.Function forward is handled correctly."""
1550+
1551+
class SimpleAutogradFunc(torch.autograd.Function):
1552+
@staticmethod
1553+
def forward(ctx, x, z):
1554+
# Simple computation
1555+
o = torch.matmul(x, x) @ x
1556+
out = x.sin()
1557+
# Mutate the nonlocal list
1558+
z.append(out)
1559+
return torch.cos(torch.sin(o)), torch.sin(x)
1560+
1561+
@staticmethod
1562+
def backward(ctx, grad_output1, grad_output2):
1563+
# Simple backward
1564+
return grad_output1 + grad_output2, None
1565+
1566+
def fn(x):
1567+
z = []
1568+
1569+
outs = SimpleAutogradFunc.apply(x, z)
1570+
out1 = outs[0]
1571+
# Check that the extra output pytree handling is done properly
1572+
out2 = outs[-1]
1573+
1574+
return out1 + out2, z[0]
1575+
1576+
x = torch.randn(4, 4, requires_grad=True)
1577+
ref = fn(x)
1578+
1579+
opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
1580+
res = opt_fn(x)
1581+
self.assertEqual(ref[0], res[0])
1582+
self.assertEqual(ref[1], res[1])
1583+
15461584

15471585
if __name__ == "__main__":
15481586
from torch._dynamo.test_case import run_tests

test/higher_order_ops/test_invoke_subgraph.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,35 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"):
910910
""",
911911
)
912912

913+
@unittest.expectedFailure
914+
def test_nonlocal_list_mutation_hidden(self):
915+
"""Test that nonlocal list mutation inside nested_compile_region is handled correctly."""
916+
917+
@nested_compile_region
918+
def gn(x, z):
919+
o = torch.matmul(x, x) @ x
920+
out = x.sin()
921+
z.append(out)
922+
return torch.cos(torch.sin(o)), torch.sin(x)
923+
924+
def fn(x):
925+
z = []
926+
927+
outs = gn(x, z)
928+
out1 = outs[0]
929+
# Check that the extra output pytree handling is done properly
930+
out2 = outs[-1]
931+
932+
return out1 + out2, z[0]
933+
934+
x = torch.randn(4, 4, requires_grad=True)
935+
ref = fn(x)
936+
937+
opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
938+
res = opt_fn(x)
939+
self.assertEqual(ref[0], res[0])
940+
self.assertEqual(ref[1], res[1])
941+
913942
@inductor_config.patch("fx_graph_cache", False)
914943
def test_view_to_reshape(self):
915944
@nested_compile_region

0 commit comments

Comments
 (0)