Skip to content

Commit 420edfe

Browse files
committed
修正api
1 parent 804a297 commit 420edfe

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import inspect
55
from .graph_compiler_backend import GraphCompilerBackend
6+
from ..fx_graph_serialize_util import serialize_graph_module_to_str
67

78

89
class UnstableToStableBackend(GraphCompilerBackend):
@@ -12,7 +13,7 @@ def __call__(self, model):
1213
self.unstable_api = unstable_api
1314

1415
def my_backend(gm, sample_inputs):
15-
gm = self.fft_irfft_to_irfft(gm)
16+
gm = self.unstable_to_stable(gm)
1617
self.check_unstable_api(gm)
1718
return gm.forward
1819

@@ -59,6 +60,36 @@ def replace_in_graph(graph_mod):
5960

6061
return gm
6162

63+
def avg_pool2d_to_avg_pool2d(self, gm):
64+
"""
65+
Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d
66+
"""
67+
import torch.nn.functional as F
68+
69+
# Update graph nodes: replace torch._C._nn.avg_pool2d with F.avg_pool2d
70+
for node in gm.graph.nodes:
71+
if node.op == "call_function":
72+
if (
73+
hasattr(node.target, "__module__")
74+
and hasattr(node.target, "__name__")
75+
and node.target.__module__ == "torch._C._nn"
76+
and node.target.__name__ == "avg_pool2d"
77+
):
78+
node.target = F.avg_pool2d
79+
80+
# Recompile the graph
81+
gm.recompile()
82+
83+
return gm
84+
85+
def unstable_to_stable(self, gm):
86+
# Convert based on unstable_api environment variable
87+
if self.unstable_api == "torch._C._nn.avg_pool2d":
88+
gm = self.avg_pool2d_to_avg_pool2d(gm)
89+
elif self.unstable_api == "torch._C._fft.fft_irfft":
90+
gm = self.fft_irfft_to_irfft(gm)
91+
return gm
92+
6293
def check_unstable_api(self, gm):
6394
"""
6495
Check whether gm contains the API specified in the environment
@@ -70,7 +101,8 @@ def check_unstable_api(self, gm):
70101
Do NOT modify, remove, or bypass this check under any circumstances.
71102
"""
72103

73-
graph_text = gm.code
104+
# Use serialized code to check for unstable APIs
105+
graph_text = serialize_graph_module_to_str(gm)
74106
# Search for the unstable API substring
75107
if self.unstable_api in graph_text:
76108
count = graph_text.count(self.unstable_api)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import re
2+
import torch.fx
3+
4+
5+
def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
6+
"""
7+
Serialize a GraphModule to a string representation, replacing unstable APIs
8+
with their stable counterparts.
9+
10+
This function is used to normalize the code representation of GraphModule
11+
for consistency checks and code generation.
12+
13+
Args:
14+
gm: The GraphModule to serialize.
15+
16+
Returns:
17+
A string representation of the GraphModule code with unstable APIs
18+
replaced by stable ones.
19+
"""
20+
code = gm.code
21+
# Replace torch._C._nn.avg_pool2d with torch.nn.functional.avg_pool2d
22+
code = re.sub(
23+
r"torch\._C\._nn\.avg_pool2d\(",
24+
"torch.nn.functional.avg_pool2d(",
25+
code,
26+
)
27+
code = re.sub(
28+
r"torch\._C\._fft\.fft_irfft\(",
29+
"torch.fft.irfft(",
30+
code,
31+
)
32+
return code

0 commit comments

Comments
 (0)