Skip to content

Commit 378a55c

Browse files
authored
Only print dde partial fx graph for export (pytorch#153218)
* Only print dde partial fx graph for export — get pytorch#149831 into 2.7.1.
* [dynamo] Add test to ensure we don't print the fx graph upon a data-dependent graph break. This adds a regression test for pytorch#149831, also as part of getting it cherry-picked into 2.7.1.

ghstack-source-id: fedc9ea
Pull Request resolved: pytorch#153416
1 parent 800aa04 commit 378a55c

File tree

4 files changed

+46
-17
lines changed

4 files changed

+46
-17
lines changed

test/dynamo/test_repros.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import itertools
1616
import os
1717
import random
18+
import sys
1819
import types
1920
import typing
2021
import unittest
@@ -6870,6 +6871,31 @@ def fn(x):
68706871
x = torch.randn(4)
68716872
self.assertEqual(fn(x), opt_fn(x))
68726873

6874+
def test_data_dependent_error_log_no_print(self):
6875+
# This is a regression test case for
6876+
# https://github.com/pytorch/pytorch/pull/149831
6877+
from io import StringIO
6878+
6879+
capturedOutput = StringIO()
6880+
sys.stderr = capturedOutput
6881+
6882+
@torch.compile(fullgraph=True)
6883+
def func(a):
6884+
if a.sum() > 0:
6885+
return a + 1
6886+
return a + 2
6887+
6888+
a = torch.rand(10, 10)
6889+
try:
6890+
func(a)
6891+
except Exception:
6892+
pass
6893+
sys.stderr = sys.__stderr__
6894+
6895+
# Make sure we don't _print_ out the graph module.
6896+
output = capturedOutput.getvalue()
6897+
self.assertNotIn("class GraphModule", output)
6898+
68736899

68746900
instantiate_parametrized_tests(ReproTests)
68756901

torch/_dynamo/symbolic_convert.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,15 +1342,13 @@ def run(self):
13421342
raise
13431343
except RuntimeError as e:
13441344
if hasattr(e, "msg") and "Data-dependent" in e.msg:
1345-
print(
1346-
"\n"
1347-
+ torch.fx.GraphModule(
1348-
self.output.nn_modules, self.output.graph
1349-
).print_readable(
1350-
print_output=False, include_stride=True, include_device=True
1351-
),
1352-
file=sys.stderr,
1345+
readable_graph = torch.fx.GraphModule(
1346+
self.output.nn_modules, self.output.graph
1347+
).print_readable(
1348+
print_output=False, include_stride=True, include_device=True
13531349
)
1350+
e.partial_fx_graph = readable_graph # type: ignore[attr-defined]
1351+
raise
13541352

13551353
raise
13561354
except Exception as e:

torch/export/_trace.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import inspect
66
import logging
77
import re
8+
import sys
89
import time
910
import warnings
1011
from contextlib import contextmanager, nullcontext
@@ -1089,6 +1090,13 @@ def wrapper(*args, **kwargs):
10891090
message=str(e),
10901091
flags=_EXPORT_FLAGS,
10911092
)
1093+
1094+
if hasattr(e, "partial_fx_graph"):
1095+
print(
1096+
e.partial_fx_graph,
1097+
file=sys.stderr,
1098+
)
1099+
10921100
raise e
10931101
finally:
10941102
_EXPORT_FLAGS = None

torch/fx/_symbolic_trace.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import inspect
88
import math
99
import os
10-
import sys
1110
import warnings
1211
from itertools import chain
1312
from types import CodeType, FunctionType, ModuleType
@@ -843,14 +842,12 @@ def forward(*args, **kwargs):
843842
self.submodule_paths = None
844843
except RuntimeError as e:
845844
if isinstance(e.args[0], str) and "data-dependent" in e.args[0]:
846-
print(
847-
"\n"
848-
+ self.graph.python_code(
849-
root_module="self",
850-
verbose=True,
851-
).src,
852-
file=sys.stderr,
853-
)
845+
partial_fx_graph = self.graph.python_code(
846+
root_module="self",
847+
verbose=True,
848+
).src
849+
e.partial_fx_graph = partial_fx_graph # type: ignore[attr-defined]
850+
raise
854851

855852
raise
856853
finally:

0 commit comments

Comments (0)