Skip to content

Commit 740c3d8

Browse files
authored
【Hackathon 9th No.113】torch._C._fft.fft_irfft API转换 torch.fft.irfft (#323)
* 英文注释,撤销原始日志文件 * 修正api
1 parent 5b9272a commit 740c3d8

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,36 @@ def my_backend(gm, sample_inputs):
3030
**Stable API reference link:**
3131
"""
3232

33+
def fft_irfft_to_irfft(self, gm):
34+
def replace_in_graph(graph_mod):
35+
# Register stable implementation on GraphModule, codegen can use self.irfft
36+
try:
37+
setattr(graph_mod, "irfft", torch.fft.irfft)
38+
except Exception:
39+
pass
40+
41+
for node in graph_mod.graph.nodes:
42+
if node.op == "call_function":
43+
# Match for all forms of target names
44+
if "fft_irfft" in str(node.target):
45+
# Directly point target to Python layer function
46+
node.target = torch.fft.irfft
47+
# Validate and recompile the graph
48+
graph_mod.graph.lint()
49+
graph_mod.recompile()
50+
51+
# Process main gm and all nested GraphModules
52+
modules = [gm]
53+
modules += [
54+
m
55+
for _, m in gm.named_modules()
56+
if isinstance(m, torch.fx.GraphModule) and m is not gm
57+
]
58+
for m in modules:
59+
replace_in_graph(m)
60+
61+
return gm
62+
3363
def avg_pool2d_to_avg_pool2d(self, gm):
3464
"""
3565
Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d
@@ -56,6 +86,8 @@ def unstable_to_stable(self, gm):
5686
# Convert based on unstable_api environment variable
5787
if self.unstable_api == "torch._C._nn.avg_pool2d":
5888
gm = self.avg_pool2d_to_avg_pool2d(gm)
89+
elif self.unstable_api == "torch._C._fft.fft_irfft":
90+
gm = self.fft_irfft_to_irfft(gm)
5991
return gm
6092

6193
def check_unstable_api(self, gm):

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,9 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2424
"torch.nn.functional.avg_pool2d(",
2525
code,
2626
)
27+
code = re.sub(
28+
r"torch\._C\._fft\.fft_irfft\(",
29+
"torch.fft.irfft(",
30+
code,
31+
)
2732
return code

0 commit comments

Comments
 (0)