Skip to content

Commit f0d46df

Browse files
committed
Convert torch._C._linalg.linalg_norm to torch.linalg.norm
1 parent de2d8c4 commit f0d46df

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,27 @@ def _impl_unstable_to_stable_special_logit(self, gm):
147147

148148
return gm
149149

150+
def _impl_unstable_to_stable_linalg_norm(self, gm):
151+
"""
152+
Convert torch._C._linalg.linalg_norm to torch.linalg.norm
153+
"""
154+
# Update graph nodes: replace torch._C._linalg.linalg_norm with torch.linalg.norm
155+
issue_nodes = (
156+
node
157+
for node in gm.graph.nodes
158+
if node.op == "call_function"
159+
if hasattr(node.target, "__module__")
160+
if node.target.__module__ == "torch._C._linalg"
161+
if hasattr(node.target, "__name__")
162+
if node.target.__name__ == "linalg_norm"
163+
)
164+
for node in issue_nodes:
165+
node.target = torch.linalg.norm
166+
167+
# Recompile the graph
168+
gm.recompile()
169+
return gm
170+
150171
def unstable_to_stable(self, gm):
151172
methods = (
152173
name

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
2727
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
28+
(r"torch\._C\._linalg\.linalg_norm\(", "torch.linalg.norm("),
2829
# Add new rules to this list as needed
2930
]
3031
for pattern, repl in replacements:

0 commit comments

Comments
 (0)