File tree Expand file tree Collapse file tree 4 files changed +46
-17
lines changed Expand file tree Collapse file tree 4 files changed +46
-17
lines changed Original file line number Diff line number Diff line change 1515import itertools
1616import os
1717import random
18+ import sys
1819import types
1920import typing
2021import unittest
@@ -6870,6 +6871,31 @@ def fn(x):
68706871 x = torch .randn (4 )
68716872 self .assertEqual (fn (x ), opt_fn (x ))
68726873
6874+ def test_data_dependent_error_log_no_print (self ):
6875+ # This is a regression test case for
6876+ # https://github.com/pytorch/pytorch/pull/149831
6877+ from io import StringIO
6878+
6879+ capturedOutput = StringIO ()
6880+ sys .stderr = capturedOutput
6881+
6882+ @torch .compile (fullgraph = True )
6883+ def func (a ):
6884+ if a .sum () > 0 :
6885+ return a + 1
6886+ return a + 2
6887+
6888+ a = torch .rand (10 , 10 )
6889+ try :
6890+ func (a )
6891+ except Exception :
6892+ pass
6893+ sys .stderr = sys .__stderr__
6894+
6895+ # Make sure we don't _print_ out the graph module.
6896+ output = capturedOutput .getvalue ()
6897+ self .assertNotIn ("class GraphModule" , output )
6898+
68736899
68746900instantiate_parametrized_tests (ReproTests )
68756901
Original file line number Diff line number Diff line change @@ -1342,15 +1342,13 @@ def run(self):
13421342 raise
13431343 except RuntimeError as e :
13441344 if hasattr (e , "msg" ) and "Data-dependent" in e .msg :
1345- print (
1346- "\n "
1347- + torch .fx .GraphModule (
1348- self .output .nn_modules , self .output .graph
1349- ).print_readable (
1350- print_output = False , include_stride = True , include_device = True
1351- ),
1352- file = sys .stderr ,
1345+ readable_graph = torch .fx .GraphModule (
1346+ self .output .nn_modules , self .output .graph
1347+ ).print_readable (
1348+ print_output = False , include_stride = True , include_device = True
13531349 )
1350+ e .partial_fx_graph = readable_graph # type: ignore[attr-defined]
1351+ raise
13541352
13551353 raise
13561354 except Exception as e :
Original file line number Diff line number Diff line change 55import inspect
66import logging
77import re
8+ import sys
89import time
910import warnings
1011from contextlib import contextmanager , nullcontext
@@ -1089,6 +1090,13 @@ def wrapper(*args, **kwargs):
10891090 message = str (e ),
10901091 flags = _EXPORT_FLAGS ,
10911092 )
1093+
1094+ if hasattr (e , "partial_fx_graph" ):
1095+ print (
1096+ e .partial_fx_graph ,
1097+ file = sys .stderr ,
1098+ )
1099+
10921100 raise e
10931101 finally :
10941102 _EXPORT_FLAGS = None
Original file line number Diff line number Diff line change 77import inspect
88import math
99import os
10- import sys
1110import warnings
1211from itertools import chain
1312from types import CodeType , FunctionType , ModuleType
@@ -843,14 +842,12 @@ def forward(*args, **kwargs):
843842 self .submodule_paths = None
844843 except RuntimeError as e :
845844 if isinstance (e .args [0 ], str ) and "data-dependent" in e .args [0 ]:
846- print (
847- "\n "
848- + self .graph .python_code (
849- root_module = "self" ,
850- verbose = True ,
851- ).src ,
852- file = sys .stderr ,
853- )
845+ partial_fx_graph = self .graph .python_code (
846+ root_module = "self" ,
847+ verbose = True ,
848+ ).src
849+ e .partial_fx_graph = partial_fx_graph # type: ignore[attr-defined]
850+ raise
854851
855852 raise
856853 finally :
You can’t perform that action at this time.
0 commit comments