
Commit 27cfdd9

zhxchen17 authored and pytorchmergebot committed
[export] Return more information from tracing context in graph capture. (pytorch#166775)
Summary: As the title says, return the entire tracing_context object from graph capture instead of only the fake_mode, since the tracing context carries the full set of capture-time information.

Test Plan: pytest test/export/test_experimental.py

Pull Request resolved: pytorch#166775
Approved by: https://github.com/tugsbayasgalan
1 parent 01d8d85 commit 27cfdd9
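
For context, here is a minimal usage sketch mirroring the new test, assuming a PyTorch build where dynamo_graph_capture_for_export is importable from torch._dynamo.functional_export; after this change the captured graph module's meta exposes the full tracing context rather than only the fake mode.

import torch
from torch._dynamo.functional_export import dynamo_graph_capture_for_export


class Foo(torch.nn.Module):
    def forward(self, x):
        return x + x.shape[0]


inputs = (torch.randn(2, 3),)
torch._dynamo.mark_dynamic(inputs[0], 0)  # mark dim 0 as dynamic

gm = dynamo_graph_capture_for_export(Foo())(*inputs)

# With this change, the captured module carries the whole TracingContext,
# not just the FakeTensorMode.
ctx = gm.meta["tracing_context"]
print(ctx.fake_mode)               # capture-time FakeTensorMode
print(len(ctx.tensor_to_context))  # per-tensor symbolic contexts (1 input here)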

File tree

2 files changed (+24 lines, -0 lines)


test/export/test_experimental.py

Lines changed: 19 additions & 0 deletions
@@ -522,6 +522,25 @@ def forward(self, args_0):
         )
         self.assertEqual(ep(*inps), MyModel()(*inps))
 
+    def test_dynamo_graph_capture_full_tracing_context(self) -> None:
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                return x + x.shape[0]
+
+        foo = Foo()
+
+        def make_inputs(b: int):
+            ret = (torch.randn(b, 3),)
+            torch._dynamo.mark_dynamic(ret[0], 0)
+            return ret
+
+        trace_inputs = make_inputs(2)
+        gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
+        test_inputs = make_inputs(3)
+        self.assertEqual(gm(*test_inputs), foo(*test_inputs))
+        self.assertIsNotNone(gm.meta["tracing_context"].fake_mode)
+        self.assertEqual(len(gm.meta["tracing_context"].tensor_to_context), 1)
+
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
     def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
         class DummyOp(torch.autograd.Function):

torch/_dynamo/functional_export.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,7 @@
 from torch._dynamo.exc import UserErrorType
 from torch._dynamo.utils import dynamo_timed, get_metrics_context
 from torch._export.utils import _compiling_state_context
+from torch._guards import TracingContext
 from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
 from torch.fx import Node
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -650,6 +651,10 @@ def inner(*args: Any, **kwargs: Any) -> Any:
         )
         assert out.backend_input is not None
         graph_module.meta["fake_mode"] = out.backend_input.fake_mode  # type: ignore[attr-defined]
+        graph_module.meta["fake_mode"].allow_non_fake_inputs = True
+        tracing_context = TracingContext(graph_module.meta["fake_mode"])
+        tracing_context.tensor_to_context = out.backend_input.tensor_to_context  # type: ignore[attr-defined]
+        graph_module.meta["tracing_context"] = tracing_context
         return graph_module
 
     return inner
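
A hedged sketch of how downstream code might re-enter the saved context (illustrative only, not part of this commit; it assumes torch._guards.tracing and TracingContext.try_get behave as on current PyTorch main):

from torch._guards import TracingContext, tracing


def run_under_capture_context(gm):
    # Re-enter the tracing context that dynamo_graph_capture_for_export stored
    # on the module, so code that queries TracingContext sees the capture-time
    # fake mode and the per-tensor symbolic contexts.
    ctx = gm.meta["tracing_context"]
    with tracing(ctx):
        active = TracingContext.try_get()
        assert active is ctx
        return active.fake_mode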
