Skip to content

Commit 63084a7

Browse files
committed
resolve merge conflicts
1 parent 992b765 commit 63084a7

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,27 @@ def _impl_unstable_to_stable_one_hot(self, gm):
149149

150150
return gm
151151

152+
def _impl_unstable_to_stable_special_logit(self, gm):
153+
"""
154+
Convert torch._C._special.special_logit to torch.special.logit
155+
"""
156+
issue_nodes = (
157+
node
158+
for node in gm.graph.nodes
159+
if node.op == "call_function"
160+
if hasattr(node.target, "__module__")
161+
if node.target.__module__ == "torch._C._special"
162+
if hasattr(node.target, "__name__")
163+
if node.target.__name__ == "special_logit"
164+
)
165+
for node in issue_nodes:
166+
node.target = torch.special.logit
167+
168+
# Recompile the graph
169+
gm.recompile()
170+
171+
return gm
172+
152173
def unstable_to_stable(self, gm):
153174
methods = (
154175
name

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
2727
(r"torch\._C\._nn\.one_hot\(", "torch.nn.functional.one_hot("),
28+
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
2829
# Add new rules to this list as needed
2930
]
3031
for pattern, repl in replacements:

0 commit comments

Comments
 (0)