Skip to content

Commit 5a640e2

Browse files
committed
英文注释,撤销原始日志文件
1 parent 96785c6 commit 5a640e2

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def __call__(self, model):
1212
self.unstable_api = unstable_api
1313

1414
def my_backend(gm, sample_inputs):
15-
gm = self.unstable_to_stable(gm)
15+
gm = self.fft_irfft_to_irfft(gm)
1616
self.check_unstable_api(gm)
1717
return gm.forward
1818

@@ -29,8 +29,34 @@ def my_backend(gm, sample_inputs):
2929
**Stable API reference link:**
3030
"""
3131

32-
def unstable_to_stable(self, gm):
33-
# TODO
32+
def fft_irfft_to_irfft(self, gm):
33+
def replace_in_graph(graph_mod):
34+
# Register stable implementation on GraphModule, codegen can use self.irfft
35+
try:
36+
setattr(graph_mod, "irfft", torch.fft.irfft)
37+
except Exception:
38+
pass
39+
40+
for node in graph_mod.graph.nodes:
41+
if node.op == "call_function":
42+
# Match for all forms of target names
43+
if "fft_irfft" in str(node.target):
44+
# Directly point target to Python layer function
45+
node.target = torch.fft.irfft
46+
# Validate and recompile the graph
47+
graph_mod.graph.lint()
48+
graph_mod.recompile()
49+
50+
# Process main gm and all nested GraphModules
51+
modules = [gm]
52+
modules += [
53+
m
54+
for _, m in gm.named_modules()
55+
if isinstance(m, torch.fx.GraphModule) and m is not gm
56+
]
57+
for m in modules:
58+
replace_in_graph(m)
59+
3460
return gm
3561

3662
def check_unstable_api(self, gm):

0 commit comments

Comments
 (0)