Skip to content

Commit 876824f

Browse files
tugsbayasgalanpytorchmergebot
authored andcommitted
Make inline tests to use new exporter and fix some issues around it (pytorch#162682)
inline_and_install_module export variant is our long term state so it is better to use the new tracer for this. It also uncovered bunch of minor bugs because with inline_and_install_module, the nn_module_stack generation is changed a bit. Differential Revision: [D82478648](https://our.internmc.facebook.com/intern/diff/D82478648) Pull Request resolved: pytorch#162682 Approved by: https://github.com/zhxchen17 ghstack dependencies: pytorch#162557, pytorch#162558, pytorch#162559
1 parent a89d5e9 commit 876824f

File tree

2 files changed

+93
-19
lines changed

2 files changed

+93
-19
lines changed

test/export/test_export_with_inline_and_install.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
import unittest
55

6-
from torch._dynamo import config
6+
from torch._dynamo import config as dynamo_config
77
from torch._dynamo.testing import make_test_cls_with_patches
8+
from torch._export import config as export_config
89

910

1011
try:
@@ -44,8 +45,9 @@ def make_dynamic_cls(cls):
4445
cls_a,
4546
cls_prefix,
4647
"",
47-
(config, "install_free_tensors", True),
48-
(config, "inline_inbuilt_nn_modules", True),
48+
(export_config, "use_new_tracer_experimental", True),
49+
(dynamo_config, "install_free_tensors", True),
50+
(dynamo_config, "inline_inbuilt_nn_modules", True),
4951
xfail_prop="_expected_failure_inline_and_install",
5052
)
5153

@@ -71,7 +73,30 @@ def make_dynamic_cls(cls):
7173
unittest.expectedFailure(
7274
InlineAndInstallStrictExportTestExport.test_buffer_util_inline_and_install_strict # noqa: F821
7375
)
74-
76+
# this is because we can't preserve stacktrace
77+
unittest.expectedFailure(
78+
InlineAndInstallStrictExportTestExport.test_stack_trace_make_fx_inline_and_install_strict # noqa: F821
79+
)
80+
# this is because we marked unlift hooks to be dynamo skip traced
81+
unittest.expectedFailure(
82+
InlineAndInstallStrictExportTestExport.test_custom_tag_metadata_re_export_inline_and_install_strict # noqa: F821
83+
)
84+
unittest.expectedFailure(
85+
InlineAndInstallStrictExportTestExport.test_from_node_metadata_export_inline_and_install_strict # noqa: F821
86+
)
87+
unittest.expectedFailure(
88+
InlineAndInstallStrictExportTestExport.test_module_inline_and_install_strict # noqa: F821
89+
)
90+
unittest.expectedFailure(
91+
InlineAndInstallStrictExportTestExport.test_module_with_dict_container_inp_out_inline_and_install_strict # noqa: F821
92+
)
93+
unittest.expectedFailure(
94+
InlineAndInstallStrictExportTestExport.test_retrace_pre_autograd_inline_and_install_strict # noqa: F821
95+
)
96+
# this is because detect leak test has export root
97+
unittest.expectedFailure(
98+
InlineAndInstallStrictExportTestExport.test_detect_leak_strict_inline_and_install_strict # noqa: F821
99+
)
75100

76101
if __name__ == "__main__":
77102
from torch._dynamo.test_case import run_tests

torch/_dynamo/functional_export.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,63 @@ def post_process_error_msg(
4949
return constraint_violation_error
5050

5151

52-
def clean_nn_module_stack(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
52+
def clean_nn_module_stack(
53+
graph_module: torch.fx.GraphModule, is_inline_builtin=False
54+
) -> torch.fx.GraphModule:
55+
"""
56+
Clean up nn_module_stack metadata by removing export_root references.
57+
58+
Removes the _export_root module references from nn_module_stack metadata
59+
in graph nodes, which are artifacts from the export process. Fixes two patterns:
60+
61+
1. Keys: Removes "__export_root_" and "__modules['_export_root']_" prefixes
62+
- Normal case: "L__self____export_root_child" -> "L__self__child"
63+
- inline_builtin case: Uses numeric ID strings like "140468831433840"
64+
65+
2. Values: Removes "._export_root" and "._modules['_export_root']" from child names
66+
e.g., "L['self']._export_root.child" -> "L['self'].child"
67+
e.g., "L['self']._modules['_export_root'].child" -> "L['self'].child"
68+
69+
Also removes the root export entry "L__self____export_root" entirely.
70+
71+
Args:
72+
graph_module: The GraphModule to clean up
73+
is_inline_builtin: If True, keys are numeric ID strings and self references
74+
(L['self']) are filtered out
75+
76+
Returns:
77+
The cleaned GraphModule (modified in-place)
78+
"""
5379
for node in graph_module.graph.nodes:
54-
if "nn_module_stack" in node.meta:
55-
nn_module_stack = node.meta["nn_module_stack"].copy()
56-
first_key = next(iter(nn_module_stack.keys()))
57-
if "export_root" in first_key:
58-
del nn_module_stack[first_key]
59-
nn_module_stack_corrected = {}
60-
for k, v in nn_module_stack.items():
61-
k_new = "".join(k.split("__export_root"))
62-
child_name, child_class = v
63-
child_name = child_name.replace("._export_root", "")
64-
nn_module_stack_corrected[k_new] = (child_name, child_class)
65-
node.meta["nn_module_stack"] = nn_module_stack_corrected
80+
if "nn_module_stack" not in node.meta:
81+
continue
82+
83+
nn_module_stack = node.meta["nn_module_stack"].copy()
84+
85+
if "L__self____export_root" in nn_module_stack:
86+
del nn_module_stack["L__self____export_root"]
87+
88+
# Clean up remaining entries
89+
cleaned_stack = {}
90+
for key, (child_name, child_class) in nn_module_stack.items():
91+
# Clean key by removing export_root patterns
92+
clean_key = key.replace("__modules['_export_root']_", "").replace(
93+
"__export_root_", ""
94+
)
95+
96+
# Clean child_name by removing export_root patterns
97+
clean_name = child_name.replace("._modules['_export_root']", "").replace(
98+
"._export_root", ""
99+
)
100+
101+
# Skip self reference for inline builtin case
102+
if is_inline_builtin and clean_name == "L['self']":
103+
continue
104+
105+
cleaned_stack[clean_key] = (clean_name, child_class)
106+
107+
node.meta["nn_module_stack"] = cleaned_stack
108+
66109
return graph_module
67110

68111

@@ -71,7 +114,11 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
71114

72115
# Clean parameter names: L__self____export_root_param -> L__self___param
73116
def clean_name(name) -> str:
74-
return name.replace("__export_root_", "_") if "__export_root_" in name else name
117+
if "____modules___export_root_" in name:
118+
return name.replace("____modules___export_root_", "_")
119+
if "__export_root_" in name:
120+
return name.replace("__export_root_", "_")
121+
return name
75122

76123
# Update get_attr nodes in-place
77124
for node in graph_module.graph.nodes:
@@ -409,7 +456,9 @@ def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
409456
)
410457
transformed_graph.recompile()
411458

412-
clean_nn_module_stack(transformed_graph)
459+
clean_nn_module_stack(
460+
transformed_graph, torch._dynamo.config.inline_inbuilt_nn_modules
461+
)
413462
clean_export_root(transformed_graph)
414463

415464
transformed_graph.meta["module_call_specs"] = module_call_spec

0 commit comments

Comments
 (0)