Skip to content

Commit 34733d4

Browse files
committed
Convert torch._C._nn.one_hot to torch.nn.functional.one_hot
1 parent e7c6e03 commit 34733d4

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,29 @@ def _impl_unstable_to_stable_fftn(self, gm):
126126

127127
return gm
128128

129+
def _impl_unstable_to_stable_one_hot(self, gm):
130+
"""
131+
Convert torch._C._nn.one_hot to torch.nn.functional.one_hot
132+
"""
133+
import torch.nn.functional as F
134+
135+
issue_nodes = (
136+
node
137+
for node in gm.graph.nodes
138+
if node.op == "call_function"
139+
if hasattr(node.target, "__module__")
140+
if node.target.__module__ == "torch._C._nn"
141+
if hasattr(node.target, "__name__")
142+
if node.target.__name__ == "one_hot"
143+
)
144+
for node in issue_nodes:
145+
node.target = F.one_hot
146+
147+
# Recompile the graph
148+
gm.recompile()
149+
150+
return gm
151+
129152
def unstable_to_stable(self, gm):
130153
methods = (
131154
name

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2424
(r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("),
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
27+
(r"torch\._C\._nn\.one_hot\(", "torch.nn.functional.one_hot("),
2728
# Add new rules to this list as needed
2829
]
2930
for pattern, repl in replacements:

0 commit comments

Comments
 (0)