Skip to content

Commit 11f73d7

Browse files
zhxchen17pytorchmergebot
authored andcommitted
[export] Downgrade captured buffers as normal constants. (pytorch#166777)
Summary: make_fx() will register tensor constants as new buffers while tracing a shuffle graph for dynamo graph capture. This breaks the invariance that the resulting graph looks identical to the original eager model in terms of state dict. So we need to de-register the buffers and set them as plain tensor constants. Test Plan: pytest test/export/test_experimental.py Pull Request resolved: pytorch#166777 Approved by: https://github.com/tugsbayasgalan ghstack dependencies: pytorch#166775, pytorch#166776
1 parent 7d1b976 commit 11f73d7

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

test/export/test_experimental.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,26 @@ def forward(self, args_0):
579579
)
580580
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
581581

582+
def test_dynamo_graph_capture_with_tensor_constant(self):
583+
outer = torch.randn(2, 3)
584+
585+
class MyModel(torch.nn.Module):
586+
def forward(self, x):
587+
z = x + outer
588+
return z
589+
590+
foo = MyModel()
591+
592+
def make_inputs():
593+
return (torch.randn(2, 3),)
594+
595+
trace_inputs = make_inputs()
596+
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
597+
test_inputs = make_inputs()
598+
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
599+
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
600+
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
601+
582602
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
583603
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
584604
class DummyOp(torch.autograd.Function):

torch/_dynamo/functional_export.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,14 @@ def _suggest_or_raise_constraint_violation(
450450
raise constraint_violation_error
451451

452452

453+
def _normalize_shuffle_graph(shuffle_gm: torch.fx.GraphModule) -> None:
454+
shuffle_gm.graph.eliminate_dead_code()
455+
shuffle_gm.recompile()
456+
for name, buffer in list(shuffle_gm.named_buffers()):
457+
delattr(shuffle_gm, name)
458+
setattr(shuffle_gm, name, buffer)
459+
460+
453461
@dataclass(frozen=True)
454462
class PyTreeifyOutput:
455463
graph_module: torch.fx.GraphModule
@@ -526,8 +534,7 @@ def backend_dummy(*example_inputs):
526534
in_shuffle_graph = make_fx(
527535
InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True
528536
)(*flat_real_args)
529-
in_shuffle_graph.graph.eliminate_dead_code()
530-
in_shuffle_graph.recompile()
537+
_normalize_shuffle_graph(in_shuffle_graph)
531538

532539
output_node = next(iter(reversed(backend_input.graph_module.graph.nodes)))
533540

@@ -575,8 +582,7 @@ def backend_dummy(*example_inputs):
575582
out_shuffle_graph = make_fx(
576583
out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
577584
)(*flat_out_shuffle_args)
578-
out_shuffle_graph.graph.eliminate_dead_code()
579-
out_shuffle_graph.recompile()
585+
_normalize_shuffle_graph(out_shuffle_graph)
580586

581587
assert out_shuffle.out_spec is not None
582588
return PyTreeifyOutput(

0 commit comments

Comments
 (0)