Commit bc3972b

yushangdi authored and pytorchmergebot committed
[reland] Add stack_trace on make_fx (pytorch#155486)
Summary:
Previously, we only added the stack trace in class `_ModuleStackTracer(PythonKeyTracer)` for non-strict export. I moved this stack-trace logic to the parent class `PythonKeyTracer`, so that graphs traced from a Module using make_fx will have `stack_trace` as well.

Motivation: we've observed some use cases where users first run make_fx on a Module and then run export on the resulting graph. If the result of make_fx doesn't have a stack trace, the stack trace information is lost.

**Users need to turn this on by passing `stack_trace=True` to make_fx. We don't make this the default option since it might increase inductor compilation time (`make_fx` is used in inductor to trace graph patterns for pattern matching). It's also turned on if `_inductor.config.trace.enabled` is True.**

**Preserving the stack trace is on by default for ModuleStackTracer, which is used for non-strict export.**

Test Plan:
```
buck run test:test_export -- -r test_stack_trace
buck run fbcode//caffe2/test/dynamo:test_dynamo -- -k test_autocast_ordering
```

Rollback Plan:

Differential Revision: D76298692

Pull Request resolved: pytorch#155486
Approved by: https://github.com/angelayi, https://github.com/zou3519
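For illustration, a minimal usage sketch of the new flag, modeled on the test added in this commit (the module, shapes, and printed output here are illustrative, not part of the change itself):

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx


class Foo(torch.nn.Module):
    def forward(self, x):
        return x * 2.0


# Opt in explicitly: stack-trace preservation is off by default for make_fx.
gm = make_fx(Foo(), stack_trace=True)(torch.randn(4, 4))

# Nodes other than placeholders/outputs should now carry the user stack
# trace pointing at the line in forward() that produced them.
for node in gm.graph.nodes:
    if node.op not in ("placeholder", "output"):
        print(node.name, node.meta.get("stack_trace", ""))
```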
1 parent 9bd0830 · commit bc3972b

2 files changed: +98 -30 lines changed


test/export/test_export.py

Lines changed: 53 additions & 0 deletions

@@ -11473,6 +11473,59 @@ def forward(self, x):
         )
         )
 
+    def test_stack_trace_make_fx(self):
+        class Foo(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                x = self.linear(x)
+                x *= 2.0
+                return x
+
+        inp = torch.randn(4, 4)
+        gm = torch.fx.experimental.proxy_tensor.make_fx(Foo(), stack_trace=True)(
+            inp,
+        )
+
+        # check correct lines are in stack trace
+        trace_mul = [node for node in gm.graph.nodes if node.name == "mul_"][
+            0
+        ].meta.get("stack_trace", "")
+        self.assertTrue(
+            re.search(r"test_export.py.*in forward\n.*x \*= 2.0", trace_mul)
+        )
+        trace_addmm = [node for node in gm.graph.nodes if node.name in ["addmm", "t"]][
+            0
+        ].meta.get("stack_trace", "")
+        self.assertTrue(
+            re.search(
+                r"test_export.py.*in forward\n.*x = self.linear\(x\)", trace_addmm
+            )
+        )
+
+        # check correct lines are still in stack trace after export
+        ep = export(
+            gm,
+            (torch.randn(4, 4),),
+        ).run_decompositions({})
+        # check correct lines are in stack trace
+        trace_mul = [node for node in ep.graph.nodes if node.name == "mul"][0].meta.get(
+            "stack_trace", ""
+        )
+        self.assertTrue(
+            re.search(r"test_export.py.*in forward\n.*x \*= 2.0", trace_mul)
+        )
+        trace_addmm = [
+            node for node in ep.graph.nodes if node.name in ["addmm", "linear"]
+        ][0].meta.get("stack_trace", "")
+        self.assertTrue(
+            re.search(
+                r"test_export.py.*in forward\n.*x = self.linear\(x\)", trace_addmm
+            )
+        )
+
     @testing.expectedFailureSerDerNonStrict  # register_constant needs to handle serialization
     @testing.expectedFailureSerDer  # register_constant needs to handle serialization
     def test_register_constant(self):

torch/fx/experimental/proxy_tensor.py

Lines changed: 45 additions & 30 deletions

@@ -1025,6 +1025,7 @@ class PythonKeyTracer(Tracer):
     tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
     torch_fn_counts: dict[OpOverload, int]
     enable_thunkify: bool = False
+    stack_trace: bool = False
 
     def __init__(self) -> None:
         super().__init__(autowrap_modules=())  # type: ignore[arg-type]

@@ -1110,6 +1111,39 @@ def create_node(
     ) -> torch.fx.Node:
         node = super().create_node(kind, target, args, kwargs, name, type_expr)  # type: ignore[arg-type]
 
+        # stack_trace
+        if (
+            self.stack_trace
+            and "stack_trace" not in node.meta
+            and node.op not in ["placeholder", "output"]
+        ):
+            user_frame_summary = CapturedTraceback.extract().summary()
+            if user_frame_summary:
+                # we retain frames from forward() calls, or ops
+                # located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap)
+                stack_trace = [
+                    frame
+                    for frame in user_frame_summary
+                    if (
+                        frame.name == "forward"
+                        or frame.filename.endswith("torch/__init__.py")
+                    )
+                ]
+                # filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py
+                # this is hardcoded, but leads to a much cleaner stack trace
+                stack_trace = [
+                    frame
+                    for frame in stack_trace
+                    if not frame.filename.endswith(
+                        ("fx/_symbolic_trace.py", "export/_trace.py")
+                    )
+                ]
+                if (
+                    stack_trace
+                ):  # empty list for strict mode, dynamo should handle stack_trace
+                    stack_trace = traceback.StackSummary.from_list(stack_trace)
+                    node.meta["stack_trace"] = "".join(stack_trace.format()).strip()
+
         def map_fn(v: Any) -> Optional[_ExtractValType]:
             if not isinstance(v, torch.fx.Node) or "val" not in v.meta:
                 return None

@@ -1659,6 +1693,7 @@ class _ModuleStackTracer(PythonKeyTracer):
 
     def __init__(self, scope_root: GraphModule) -> None:
         super().__init__()
+        self.stack_trace = True
         self.scope_root = scope_root
         self.enable_attr_proxy = False
         self.submodule_paths = {}

@@ -1909,36 +1944,6 @@ def create_node(self, *args: object, **kwargs: object) -> fx.node.Node:
             f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}",
         )
 
-        # stack_trace
-        if "stack_trace" not in node.meta and node.op not in ["placeholder", "output"]:
-            user_frame_summary = CapturedTraceback.extract().summary()
-            if user_frame_summary:
-                # we retain frames from forward() calls, or ops
-                # located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap)
-                stack_trace = [
-                    frame
-                    for frame in user_frame_summary
-                    if (
-                        frame.name == "forward"
-                        or frame.filename.endswith("torch/__init__.py")
-                    )
-                ]
-                # filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py
-                # this is hardcoded, but leads to a much cleaner stack trace
-                stack_trace = [
-                    frame
-                    for frame in stack_trace
-                    if not (
-                        frame.filename.endswith("fx/_symbolic_trace.py")
-                        or frame.filename.endswith("export/_trace.py")
-                    )
-                ]
-                if (
-                    stack_trace
-                ):  # empty list for strict mode, dynamo should handle stack_trace
-                    stack_trace = traceback.StackSummary.from_list(stack_trace)
-                    node.meta["stack_trace"] = "".join(stack_trace.format()).strip()
-
         return node

@@ -1952,6 +1957,7 @@ def __init__(
         record_module_stack: bool,
         _allow_fake_constant: bool,
         _error_on_data_dependent_ops: bool,
+        stack_trace: bool = False,
     ) -> None:
         # Configurations that are used to initialize the context managers and their states.
         # Should not modify them during tracing.

@@ -1982,6 +1988,7 @@ def __init__(
         self.torch_fn_metadata_mode: Union[
             nullcontext, TorchFunctionMetadataMode
         ] = nullcontext()
+        self.stack_trace = stack_trace
 
     def _checkpoint_modes(self) -> list[Any]:
         return [

@@ -2020,9 +2027,11 @@ def _init_modes_from_inputs(
 
         if hasattr(f, "_orig_mod") and self.record_module_stack:
             scope_root = f._orig_mod
+            # _ModuleStackTracer always try to preserve stack trace
            self.fx_tracer = _ModuleStackTracer(scope_root)
         else:
             self.fx_tracer = PythonKeyTracer()
+            self.fx_tracer.stack_trace = self.stack_trace
 
         if self.tracing_mode == "fake":
             import torch._dynamo

@@ -2274,15 +2283,20 @@ def make_fx(
     record_module_stack: bool = False,
     _allow_fake_constant: bool = False,
     _error_on_data_dependent_ops: bool = True,
+    stack_trace: bool = False,
 ) -> Callable[..., GraphModule]:
     """
     Given a function f, return a new function which when executed with valid
     arguments to f, returns an FX GraphModule representing the set of operations that
     were executed during the course of execution.
+
+    If stack_trace is True, the stack_trace will be preserved on node.meta["stack_trace"]
     """
 
     assert tracing_mode in ["real", "fake", "symbolic"]
 
+    from torch._inductor import config
+
     make_fx_tracer = _MakefxTracer(
         decomposition_table,
         tracing_mode,

@@ -2291,6 +2305,7 @@ def make_fx(
         record_module_stack,
         _allow_fake_constant,
         _error_on_data_dependent_ops,
+        stack_trace=stack_trace or config.trace.enabled,
     )
 
     @functools.wraps(f)
