Skip to content

Commit 05f323c

Browse files
committed
updata code
1 parent 96faeae commit 05f323c

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,17 @@ def fft_rfft_to_rfft(self, gm):
5757
Convert torch._C._fft.fft_rfft to torch.fft.rfft
5858
"""
5959
# Update graph nodes: replace torch._C._fft.fft_rfft with torch.fft.rfft
60-
for node in gm.graph.nodes:
61-
if node.op == "call_function":
62-
if (
63-
hasattr(node.target, "__module__")
64-
and hasattr(node.target, "__name__")
65-
and node.target.__module__ == "torch._C._fft"
66-
and node.target.__name__ == "fft_rfft"
67-
):
68-
node.target = torch.fft.rfft
60+
issue_nodes = (
61+
node
62+
for node in gm.graph.nodes
63+
if node.op == "call_function"
64+
if hasattr(node.target, "__module__")
65+
if node.target.__module__ == "torch._C._fft"
66+
if hasattr(node.target, "__name__")
67+
if node.target.__name__ == "fft_rfft"
68+
)
69+
for node in issue_nodes:
70+
node.target = torch.fft.rfft
6971

7072
# Recompile the graph
7173
gm.recompile()

0 commit comments

Comments
 (0)