Skip to content

Commit e7c6e03

Browse files
authored
【Hackathon 9th No.115】torch._C._fft.fft_fftn API转换 torch.fft.fftn (#332)
* replace torch._C._fft.fft_fftn with torch.fft.fftn * update code * resolve merge conflicts
1 parent d5ce907 commit e7c6e03

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,28 @@ def _impl_unstable_to_stable_rfft(self, gm):
104104

105105
return gm
106106

107+
def _impl_unstable_to_stable_fftn(self, gm):
108+
"""
109+
Convert torch._C._fft.fft_fftn to torch.fft.fftn
110+
"""
111+
# Update graph nodes: replace torch._C._fft.fft_fftn with torch.fft.fftn
112+
issue_nodes = (
113+
node
114+
for node in gm.graph.nodes
115+
if node.op == "call_function"
116+
if hasattr(node.target, "__module__")
117+
if node.target.__module__ == "torch._C._fft"
118+
if hasattr(node.target, "__name__")
119+
if node.target.__name__ == "fft_fftn"
120+
)
121+
for node in issue_nodes:
122+
node.target = torch.fft.fftn
123+
124+
# Recompile the graph
125+
gm.recompile()
126+
127+
return gm
128+
107129
def unstable_to_stable(self, gm):
108130
methods = (
109131
name

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2323
(r"torch\._C\._nn\.avg_pool2d\(", "torch.nn.functional.avg_pool2d("),
2424
(r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("),
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
26+
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
2627
# Add new rules to this list as needed
2728
]
2829
for pattern, repl in replacements:

0 commit comments

Comments
 (0)