Skip to content

Commit 8f6d637

Browse files
committed
解决冲突问题
1 parent 169b92d commit 8f6d637

File tree

2 files changed

+47
-17
lines changed

2 files changed

+47
-17
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,37 @@ def my_backend(gm, sample_inputs):
3030
**Stable API reference link:**
3131
"""
3232

33-
def avg_pool2d_to_avg_pool2d(self, gm):
33+
def _impl_unstable_to_stable_irfft(self, gm):
34+
def replace_in_graph(graph_mod):
35+
# Register stable implementation on GraphModule, codegen can use self.irfft
36+
try:
37+
setattr(graph_mod, "irfft", torch.fft.irfft)
38+
except Exception:
39+
pass
40+
41+
for node in graph_mod.graph.nodes:
42+
if node.op == "call_function":
43+
# Match for all forms of target names
44+
if "fft_irfft" in str(node.target):
45+
# Directly point target to Python layer function
46+
node.target = torch.fft.irfft
47+
# Validate and recompile the graph
48+
graph_mod.graph.lint()
49+
graph_mod.recompile()
50+
51+
# Process main gm and all nested GraphModules
52+
modules = [gm]
53+
modules += [
54+
m
55+
for _, m in gm.named_modules()
56+
if isinstance(m, torch.fx.GraphModule) and m is not gm
57+
]
58+
for m in modules:
59+
replace_in_graph(m)
60+
61+
return gm
62+
63+
def _impl_unstable_to_stable_avg_pool2d(self, gm):
3464
"""
3565
Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d
3666
"""
@@ -52,7 +82,7 @@ def avg_pool2d_to_avg_pool2d(self, gm):
5282

5383
return gm
5484

55-
def fft_rfft_to_rfft(self, gm):
85+
def _impl_unstable_to_stable_rfft(self, gm):
5686
"""
5787
Convert torch._C._fft.fft_rfft to torch.fft.rfft
5888
"""
@@ -75,11 +105,13 @@ def fft_rfft_to_rfft(self, gm):
75105
return gm
76106

77107
def unstable_to_stable(self, gm):
78-
# Convert based on unstable_api environment variable
79-
if self.unstable_api == "torch._C._nn.avg_pool2d":
80-
gm = self.avg_pool2d_to_avg_pool2d(gm)
81-
elif self.unstable_api == "torch._C._fft.fft_rfft":
82-
gm = self.fft_rfft_to_rfft(gm)
108+
methods = (
109+
name
110+
for name in vars(type(self)).keys()
111+
if name.startswith("_impl_unstable_to_stable")
112+
)
113+
for method in methods:
114+
gm = getattr(self, method)(gm)
83115
return gm
84116

85117
def check_unstable_api(self, gm):

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
1919
"""
2020
code = gm.code
2121
# Replace torch._C._nn.avg_pool2d with torch.nn.functional.avg_pool2d
22-
code = re.sub(
23-
r"torch\._C\._nn\.avg_pool2d\(",
24-
"torch.nn.functional.avg_pool2d(",
25-
code,
26-
)
27-
code = re.sub(
28-
r"torch\._C\._fft\.fft_rfft\(",
29-
"torch.fft.rfft(",
30-
code,
31-
)
22+
replacements = [
23+
(r"torch\._C\._nn\.avg_pool2d\(", "torch.nn.functional.avg_pool2d("),
24+
(r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("),
25+
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
26+
# Add new rules to this list as needed
27+
]
28+
for pattern, repl in replacements:
29+
code = re.sub(pattern, repl, code)
3230
return code

0 commit comments

Comments
 (0)