@@ -49,20 +49,63 @@ def post_process_error_msg(
49
49
return constraint_violation_error
50
50
51
51
52
- def clean_nn_module_stack (graph_module : torch .fx .GraphModule ) -> torch .fx .GraphModule :
52
+ def clean_nn_module_stack (
53
+ graph_module : torch .fx .GraphModule , is_inline_builtin = False
54
+ ) -> torch .fx .GraphModule :
55
+ """
56
+ Clean up nn_module_stack metadata by removing export_root references.
57
+
58
+ Removes the _export_root module references from nn_module_stack metadata
59
+ in graph nodes, which are artifacts from the export process. Fixes two patterns:
60
+
61
+ 1. Keys: Removes "__export_root_" and "__modules['_export_root']_" prefixes
62
+ - Normal case: "L__self____export_root_child" -> "L__self__child"
63
+ - inline_builtin case: Uses numeric ID strings like "140468831433840"
64
+
65
+ 2. Values: Removes "._export_root" and "._modules['_export_root']" from child names
66
+ e.g., "L['self']._export_root.child" -> "L['self'].child"
67
+ e.g., "L['self']._modules['_export_root'].child" -> "L['self'].child"
68
+
69
+ Also removes the root export entry "L__self____export_root" entirely.
70
+
71
+ Args:
72
+ graph_module: The GraphModule to clean up
73
+ is_inline_builtin: If True, keys are numeric ID strings and self references
74
+ (L['self']) are filtered out
75
+
76
+ Returns:
77
+ The cleaned GraphModule (modified in-place)
78
+ """
53
79
for node in graph_module .graph .nodes :
54
- if "nn_module_stack" in node .meta :
55
- nn_module_stack = node .meta ["nn_module_stack" ].copy ()
56
- first_key = next (iter (nn_module_stack .keys ()))
57
- if "export_root" in first_key :
58
- del nn_module_stack [first_key ]
59
- nn_module_stack_corrected = {}
60
- for k , v in nn_module_stack .items ():
61
- k_new = "" .join (k .split ("__export_root" ))
62
- child_name , child_class = v
63
- child_name = child_name .replace ("._export_root" , "" )
64
- nn_module_stack_corrected [k_new ] = (child_name , child_class )
65
- node .meta ["nn_module_stack" ] = nn_module_stack_corrected
80
+ if "nn_module_stack" not in node .meta :
81
+ continue
82
+
83
+ nn_module_stack = node .meta ["nn_module_stack" ].copy ()
84
+
85
+ if "L__self____export_root" in nn_module_stack :
86
+ del nn_module_stack ["L__self____export_root" ]
87
+
88
+ # Clean up remaining entries
89
+ cleaned_stack = {}
90
+ for key , (child_name , child_class ) in nn_module_stack .items ():
91
+ # Clean key by removing export_root patterns
92
+ clean_key = key .replace ("__modules['_export_root']_" , "" ).replace (
93
+ "__export_root_" , ""
94
+ )
95
+
96
+ # Clean child_name by removing export_root patterns
97
+ clean_name = child_name .replace ("._modules['_export_root']" , "" ).replace (
98
+ "._export_root" , ""
99
+ )
100
+
101
+ # Skip self reference for inline builtin case
102
+ if is_inline_builtin and clean_name == "L['self']" :
103
+ continue
104
+
105
+ cleaned_stack [clean_key ] = (clean_name , child_class )
106
+
107
+ node .meta ["nn_module_stack" ] = cleaned_stack
108
+
66
109
return graph_module
67
110
68
111
@@ -71,7 +114,11 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
71
114
72
115
# Clean parameter names: L__self____export_root_param -> L__self___param
73
116
def clean_name (name ) -> str :
74
- return name .replace ("__export_root_" , "_" ) if "__export_root_" in name else name
117
+ if "____modules___export_root_" in name :
118
+ return name .replace ("____modules___export_root_" , "_" )
119
+ if "__export_root_" in name :
120
+ return name .replace ("__export_root_" , "_" )
121
+ return name
75
122
76
123
# Update get_attr nodes in-place
77
124
for node in graph_module .graph .nodes :
@@ -409,7 +456,9 @@ def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
409
456
)
410
457
transformed_graph .recompile ()
411
458
412
- clean_nn_module_stack (transformed_graph )
459
+ clean_nn_module_stack (
460
+ transformed_graph , torch ._dynamo .config .inline_inbuilt_nn_modules
461
+ )
413
462
clean_export_root (transformed_graph )
414
463
415
464
transformed_graph .meta ["module_call_specs" ] = module_call_spec
0 commit comments