
Commit 37efaf4

xmfan authored and pytorchmergebot committed
[ca][api] config api shouldn't error with optimize_assert (pytorch#153193)
Toggling on `torch._dynamo.config.compiled_autograd = True` was causing export to error, because `optimize_assert` did not have `rebuild_ctx` defined. Separately, add a way to `rebuild_ctx` for `optimize_assert`, since it is a public API.

Pull Request resolved: pytorch#153193
Approved by: https://github.com/jansel
1 parent a4459cd commit 37efaf4
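
The scenario this fixes, as a minimal sketch mirroring the new test (assumes a build that includes this patch; the function `f` and the `"eager"` backend are illustrative):

import torch

# The config toggle from the commit message.
torch._dynamo.config.compiled_autograd = True

def f(x):
    return x.sin().sum()

# Previously this path could hit the undefined `rebuild_ctx`; with the patch,
# optimize_assert builds a default rebuild_ctx and runs normally.
compiled = torch._dynamo.optimize_assert("eager")(f)
out = compiled(torch.randn(4, requires_grad=True))
out.backward()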

File tree

2 files changed: +65 -63 lines

test/inductor/test_compiled_autograd.py

Lines changed: 50 additions & 61 deletions
@@ -660,83 +660,72 @@ def fn():
 
         self.check_output_and_recompiles(fn)
 
-    def test_torch_compile_api_inductor(self):
-        def fn():
-            torch.manual_seed(123)
-            model = torch.nn.Sequential(
-                torch.nn.Linear(4, 4),
-                torch.nn.Sigmoid(),
-            )
-
+    @parametrize("api", ("compile", "optimize"))
+    @parametrize("backend", ("eager", "aot_eager", "inductor"))
+    def test_compile_api(self, api, backend):
+        def wrap(fn, backend):
+            if api == "compile":
+                return torch.compile(fn, backend=backend)
+            elif api == "optimize":
+                return torch._dynamo.optimize(backend)(fn)
+
+        def fn(model, inputs):
             res = []
-            for _ in range(3):
-                x = torch.randn([1, 4])
-
-                result = model(x).sum()
+            for inp in inputs:
+                result = model(inp).sum()
                 result.backward()
                 res.append(model[0].weight.grad)
                 res.append(model[0].bias.grad)
                 model.zero_grad()
             return res
 
-        expected = fn()
-        with config.patch(compiled_autograd=True):
-            compiled_fn = torch.compile(fn)
-            actual = compiled_fn()
-        self.assertEqual(expected, actual)
-        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
-
-    def test_torch_compile_api_aot_eager(self):
-        def fn():
-            torch.manual_seed(123)
-            model = torch.nn.Sequential(
-                torch.nn.Linear(4, 4),
-                torch.nn.Sigmoid(),
-            )
-
-            res = []
-            for _ in range(3):
-                x = torch.randn([1, 4])
-
-                result = model(x).sum()
-                result.backward()
-                res.append(model[0].weight.grad)
-                res.append(model[0].bias.grad)
-                model.zero_grad()
-            return res
+        torch.manual_seed(123)
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.Sigmoid(),
+        )
+        inputs = [
+            torch.randn([1, 4]),
+            torch.randn([2, 4]),
+            torch.randn([3, 4]),
+        ]
 
-        expected = fn()
+        expected = fn(model, inputs)
         with config.patch(compiled_autograd=True):
-            compiled_fn = torch.compile(fn, backend="aot_eager")
-            actual = compiled_fn()
+            compiled_fn = wrap(fn, backend)
+            actual = compiled_fn(model, inputs)
         self.assertEqual(expected, actual)
-        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+        self.assertEqual(counters["compiled_autograd"]["captures"], 2)
 
-    def test_torch_compile_api_eager(self):
-        def fn():
-            torch.manual_seed(123)
-            model = torch.nn.Sequential(
-                torch.nn.Linear(4, 4),
-                torch.nn.Sigmoid(),
-            )
+    @parametrize("backend", ("eager", "aot_eager", "inductor"))
+    def test_optimize_assert(self, backend):
+        # can be merged into the test above once we support
+        # no graph break on .backward
 
-            res = []
-            for _ in range(3):
-                x = torch.randn([1, 4])
+        def fn(model, inp):
+            # NOTE: not calling .backward in the compiled fn
+            return model(inp).sum()
 
-                result = model(x).sum()
-                result.backward()
-                res.append(model[0].weight.grad)
-                res.append(model[0].bias.grad)
-                model.zero_grad()
-            return res
+        torch.manual_seed(123)
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.Sigmoid(),
+        )
+        inp = torch.randn([1, 4])
 
-        expected = fn()
+        out = fn(model, inp)
+        out.backward()
+        expected = [p.grad for p in model.parameters()]
+        model.zero_grad()
         with config.patch(compiled_autograd=True):
-            compiled_fn = torch.compile(fn, backend="eager")
-            actual = compiled_fn()
+            compiled_fn = torch._dynamo.optimize_assert(backend)(fn)
+
+            # should not error due to undefined `rebuild_ctx`
+            out = compiled_fn(model, inp)
+            out.backward()
+            actual = [p.grad for p in model.parameters()]
         self.assertEqual(expected, actual)
-        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+        self.assertEqual(counters["compiled_autograd"]["captures"], 0)
 
     def test_multiple_torch_compile(self):
         model = torch.nn.Sequential(

torch/_dynamo/eval_frame.py

Lines changed: 15 additions & 2 deletions
@@ -1881,14 +1881,27 @@ def graph_with_interpreter(*args):
     return inner
 
 
-def optimize_assert(
+def optimize_assert(*args, **kwargs):
+    if "rebuild_ctx" in kwargs and kwargs["rebuild_ctx"] is not None:
+        # called from optimize
+        rebuild_ctx = kwargs["rebuild_ctx"]
+        del kwargs["rebuild_ctx"]
+    else:
+
+        def rebuild_ctx():
+            return optimize_assert(*args, **kwargs)
+
+    return _optimize_assert(rebuild_ctx, *args, **kwargs)
+
+
+def _optimize_assert(
+    rebuild_ctx: Callable[[], OptimizeContext],
     backend,
     *,
     hooks=Hooks(None, None, None),
     export=False,
     export_constraints=None,
     dynamic=None,
-    rebuild_ctx=None,
 ):
     """
     The same as `torch._dynamo.optimize(backend, nopython=True)`
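
As a usage sketch of the public entry points after this change (the function `f` is illustrative): `optimize_assert` can be called directly, in which case the new wrapper builds its own `rebuild_ctx` closure, or it can be reached via `torch._dynamo.optimize(backend, nopython=True)`, which, per the `# called from optimize` branch above, passes `rebuild_ctx` in explicitly.

import torch

def f(x):
    return x.sin().sum()

# Direct call: the wrapper defines a default rebuild_ctx internally.
g1 = torch._dynamo.optimize_assert("eager")(f)

# Equivalent path per the docstring: optimize(..., nopython=True) supplies
# rebuild_ctx itself before delegating to _optimize_assert.
g2 = torch._dynamo.optimize("eager", nopython=True)(f)

print(g1(torch.randn(4)), g2(torch.randn(4)))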
