Skip to content

Commit 82ae79a

Browse files
committed
refactor: extract GraphModule serialization logic to fx_graph_serialize_util
- Create fx_graph_serialize_util.py with serialize_graph_module_to_str function - Move unstable API replacement logic from unstable_to_stable_backend to the new utility - Update unstable_to_stable_backend to use serialize_graph_module_to_str - Update extractor.py to use serialize_graph_module_to_str for code serialization - This refactoring makes the serialization logic reusable across the codebase
1 parent 99ea131 commit 82ae79a

File tree

3 files changed

+34
-18
lines changed

3 files changed

+34
-18
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import inspect
55
from .graph_compiler_backend import GraphCompilerBackend
6+
from ..fx_graph_serialize_util import serialize_graph_module_to_str
67

78

89
class UnstableToStableBackend(GraphCompilerBackend):
@@ -34,7 +35,6 @@ def avg_pool2d_to_avg_pool2d(self, gm):
3435
Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d
3536
"""
3637
import torch.nn.functional as F
37-
import re
3838

3939
# Update graph nodes: replace torch._C._nn.avg_pool2d with F.avg_pool2d
4040
for node in gm.graph.nodes:
@@ -50,19 +50,6 @@ def avg_pool2d_to_avg_pool2d(self, gm):
5050
# Recompile the graph
5151
gm.recompile()
5252

53-
# Replace in code string for check_unstable_api
54-
# Since torch._C._nn.avg_pool2d and F.avg_pool2d are the same object,
55-
# the generated code will still show torch._C._nn.avg_pool2d
56-
# So we need to replace it in the code string
57-
code = gm.code
58-
modified_code = re.sub(
59-
r"torch\._C\._nn\.avg_pool2d\(",
60-
"torch.nn.functional.avg_pool2d(",
61-
code,
62-
)
63-
# Store modified code for check_unstable_api to use
64-
gm._code_for_check = modified_code
65-
6653
return gm
6754

6855
def unstable_to_stable(self, gm):
@@ -82,8 +69,8 @@ def check_unstable_api(self, gm):
8269
Do NOT modify, remove, or bypass this check under any circumstances.
8370
"""
8471

85-
# Use modified code if available (from conversion), otherwise use original code
86-
graph_text = getattr(gm, "_code_for_check", None) or gm.code
72+
# Use serialized code to check for unstable APIs
73+
graph_text = serialize_graph_module_to_str(gm)
8774
# Search for the unstable API substring
8875
if self.unstable_api in graph_text:
8976
count = graph_text.count(self.unstable_api)

graph_net/torch/extractor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import shutil
55
from typing import Union, Callable
66
from . import utils
7+
from .fx_graph_serialize_util import serialize_graph_module_to_str
78

89
torch._dynamo.config.capture_scalar_outputs = True
910
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -89,9 +90,9 @@ def try_rename_placeholder(node):
8990
assert input_idx == len(sample_inputs)
9091
if self.mut_graph_codes is not None:
9192
assert isinstance(self.mut_graph_codes, list)
92-
self.mut_graph_codes.append(gm.code)
93+
self.mut_graph_codes.append(serialize_graph_module_to_str(gm))
9394
# 3. Generate and save model code
94-
base_code = gm.code
95+
base_code = serialize_graph_module_to_str(gm)
9596
# gm.graph.print_tabular()
9697
write_code = utils.apply_templates(base_code)
9798
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import re
2+
import torch.fx
3+
4+
5+
def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
6+
"""
7+
Serialize a GraphModule to a string representation, replacing unstable APIs
8+
with their stable counterparts.
9+
10+
This function is used to normalize the code representation of GraphModule
11+
for consistency checks and code generation.
12+
13+
Args:
14+
gm: The GraphModule to serialize.
15+
16+
Returns:
17+
A string representation of the GraphModule code with unstable APIs
18+
replaced by stable ones.
19+
"""
20+
code = gm.code
21+
# Replace torch._C._nn.avg_pool2d with torch.nn.functional.avg_pool2d
22+
code = re.sub(
23+
r"torch\._C\._nn\.avg_pool2d\(",
24+
"torch.nn.functional.avg_pool2d(",
25+
code,
26+
)
27+
return code
28+

0 commit comments

Comments
 (0)