Skip to content

Commit fccd908

Browse files
committed
Implement conversion for torch._C._linalg.linalg_vector_norm to stable API torch.linalg.vector_norm in UnstableToStableBackend class. Update serialization utility to include new replacement rule for vector_norm. Refactor serialization logic for better reusability across the codebase.
1 parent 06b8dc2 commit fccd908

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,38 @@ def _impl_unstable_to_stable_fftn(self, gm):
126126

127127
return gm
128128

129+
def _impl_unstable_to_stable_linalg_vector_norm(self, gm):
130+
"""
131+
Convert torch._C._linalg.linalg_vector_norm to torch.linalg.vector_norm
132+
"""
133+
def replace_in_graph(graph_mod):
134+
# Update graph nodes: replace torch._C._linalg.linalg_vector_norm with torch.linalg.vector_norm
135+
for node in graph_mod.graph.nodes:
136+
if node.op == "call_function":
137+
if (
138+
hasattr(node.target, "__module__")
139+
and hasattr(node.target, "__name__")
140+
and node.target.__module__ == "torch._C._linalg"
141+
and node.target.__name__ == "linalg_vector_norm"
142+
):
143+
node.target = torch.linalg.vector_norm
144+
145+
# Validate and recompile the graph
146+
graph_mod.graph.lint()
147+
graph_mod.recompile()
148+
149+
# Process main gm and all nested GraphModules
150+
modules = [gm]
151+
modules += [
152+
m
153+
for _, m in gm.named_modules()
154+
if isinstance(m, torch.fx.GraphModule) and m is not gm
155+
]
156+
for m in modules:
157+
replace_in_graph(m)
158+
159+
return gm
160+
129161
def unstable_to_stable(self, gm):
130162
methods = (
131163
name

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2424
(r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("),
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
27+
(r"torch\._C\._linalg\.linalg_vector_norm\(", "torch.linalg.vector_norm("),
2728
# Add new rules to this list as needed
2829
]
2930
for pattern, repl in replacements:

0 commit comments

Comments
 (0)