Skip to content

Commit 9b02221

Browse files
committed
resolve merge conflicts
1 parent 6af81a8 commit 9b02221

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,26 @@ def _impl_unstable_to_stable_softplus(self, gm):
147147
gm.recompile()
148148
return gm
149149

150+
def _impl_unstable_to_stable_special_logit(self, gm):
151+
"""
152+
Convert torch._C._special.special_logit to torch.special.logit
153+
"""
154+
issue_nodes = (
155+
node
156+
for node in gm.graph.nodes
157+
if node.op == "call_function"
158+
if hasattr(node.target, "__module__")
159+
if node.target.__module__ == "torch._C._special"
160+
if hasattr(node.target, "__name__")
161+
if node.target.__name__ == "special_logit"
162+
)
163+
for node in issue_nodes:
164+
node.target = torch.special.logit
165+
166+
# Recompile the graph
167+
gm.recompile()
168+
return gm
169+
150170
def unstable_to_stable(self, gm):
151171
methods = (
152172
name

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
2727
(r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("),
28+
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
2829
# Add new rules to this list as needed
2930
]
3031
for pattern, repl in replacements:

0 commit comments

Comments
 (0)