Skip to content

Commit 7a0717a

Browse files
authored
【Hackathon 9th No.118】Convert torch._C._nn.softplus to torch.nn.functional.softplus (#334)
* Convert torch._C._nn.softplus to torch.nn.functional.softplus * resolve merge conflicts * resolve merge conflicts
1 parent 8fb4ad6 commit 7a0717a

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,32 @@ def _impl_unstable_to_stable_special_logit(self, gm):
154154

155155
# Recompile the graph
156156
gm.recompile()
157-
158157
return gm
159158

160159
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
161160

162161
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
163162

164-
# replace this line with modification code for task 118 (torch._C._nn.softplus)
163+
def _impl_unstable_to_stable_softplus(self, gm):
164+
"""
165+
Convert torch._C._nn.softplus to torch.nn.functional.softplus
166+
"""
167+
import torch.nn.functional as F
168+
169+
issue_nodes = (
170+
node
171+
for node in gm.graph.nodes
172+
if node.op == "call_function"
173+
if hasattr(node.target, "__module__")
174+
if node.target.__module__ == "torch._C._nn"
175+
if hasattr(node.target, "__name__")
176+
if node.target.__name__ == "softplus"
177+
)
178+
for node in issue_nodes:
179+
node.target = F.softplus
180+
181+
gm.recompile()
182+
return gm
165183

166184
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
167185

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
141141
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
142142
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
143143
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
144-
# replace this line with modification code for task 118 (torch._C._nn.softplus)
144+
(r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("),
145145
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
146146
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)
147147
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)

0 commit comments

Comments
 (0)