
Commit cd08bc3

Convert torch._C._nn.gelu to torch.nn.functional.gelu
1 parent c65f7fa commit cd08bc3

2 files changed: +26, −2 lines

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 25 additions & 1 deletion
@@ -161,7 +161,31 @@ def _impl_unstable_to_stable_special_logit(self, gm):
 
     # replace this line with modification code for task 123 (torch._C._nn.pad)
 
-    # replace this line with modification code for task 125 (torch._C._nn.gelu)
+    def _impl_unstable_to_stable_gelu(self, gm):
+        """
+        Convert torch._C._nn.gelu to torch.nn.functional.gelu
+        """
+        import torch.nn.functional as F
+
+        def replace_in_graph(graph_mod):
+            for node in graph_mod.graph.nodes:
+                if node.op == "call_function":
+                    if "gelu" in str(node.target) and "torch._C._nn" in str(
+                        node.target
+                    ):
+                        node.target = F.gelu
+            graph_mod.recompile()
+
+        modules = [gm]
+        modules += [
+            m
+            for _, m in gm.named_modules()
+            if isinstance(m, torch.fx.GraphModule) and m is not gm
+        ]
+        for m in modules:
+            replace_in_graph(m)
+
+        return gm
 
     # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
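The pass above retargets matching call_function nodes in place and recompiles each torch.fx.GraphModule. Below is a minimal standalone sketch of the same retargeting against a hand-built FX graph; the one-op graph, the tensor shape, and the simplified substring match on the target are illustrative assumptions, not code from this repository.

import torch
import torch.fx
import torch.nn.functional as F

# Build a one-op FX graph whose call_function target is the private
# torch._C._nn.gelu binding, the node shape this pass is meant to rewrite.
graph = torch.fx.Graph()
x = graph.placeholder("x")
y = graph.call_function(torch._C._nn.gelu, (x,))
graph.output(y)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

# Retarget the node to the public, stable API and regenerate the code.
for node in gm.graph.nodes:
    if node.op == "call_function" and "gelu" in str(node.target):
        node.target = F.gelu
gm.recompile()

t = torch.randn(2, 3)
assert torch.allclose(gm(t), F.gelu(t))  # numerically identical op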

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
         # replace this line with modification code for task 121 (torch._C._set_grad_enabled)
         # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
         # replace this line with modification code for task 123 (torch._C._nn.pad)
-        # replace this line with modification code for task 125 (torch._C._nn.gelu)
+        (r"torch\._C\._nn\.gelu\(", "torch.nn.functional.gelu("),
         # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
         # replace this line with modification code for task 127 (torch._C._nn.linear)
     ]
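For illustration, here is how the new (pattern, replacement) pair behaves when applied with re.sub to one line of serialized FX code. The sample input line and the use of re.sub are assumptions for the sketch; the loop in serialize_graph_module_to_str that applies these pairs is not shown in this diff.

import re

pattern, repl = (r"torch\._C\._nn\.gelu\(", "torch.nn.functional.gelu(")
# Representative line of generated FX code (hypothetical example).
src = "gelu = torch._C._nn.gelu(add_1);  add_1 = None"
print(re.sub(pattern, repl, src))
# -> gelu = torch.nn.functional.gelu(add_1);  add_1 = None

Escaping the dots and anchoring on the opening parenthesis keeps the substitution from touching unrelated identifiers that merely contain "gelu".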
