Commit abae03e

Convert torch._C._nn.gelu to torch.nn.functional.gelu
1 parent: ea928bb

File tree: 2 files changed (+26, −2 lines)

graph_net/torch/backend/unstable_to_stable_backend.py (25 additions, 1 deletion)

@@ -292,7 +292,31 @@ def replace_in_graph(graph_mod):
 
         return gm
 
-    # replace this line with modification code for task 125 (torch._C._nn.gelu)
+    def _impl_unstable_to_stable_gelu(self, gm):
+        """
+        Convert torch._C._nn.gelu to torch.nn.functional.gelu
+        """
+        import torch.nn.functional as F
+
+        def replace_in_graph(graph_mod):
+            for node in graph_mod.graph.nodes:
+                if node.op == "call_function":
+                    if "gelu" in str(node.target) and "torch._C._nn" in str(
+                        node.target
+                    ):
+                        node.target = F.gelu
+            graph_mod.recompile()
+
+        modules = [gm]
+        modules += [
+            m
+            for _, m in gm.named_modules()
+            if isinstance(m, torch.fx.GraphModule) and m is not gm
+        ]
+        for m in modules:
+            replace_in_graph(m)
+
+        return gm
 
     # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
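For context, a minimal self-contained sketch (not part of the commit) of the node-retargeting pattern the new method uses. ToyModel is a hypothetical module; for brevity the sketch matches nodes only on the "gelu" substring, whereas the method above also requires "torch._C._nn" in the target's string form.

import torch
import torch.fx
import torch.nn.functional as F


class ToyModel(torch.nn.Module):
    # Hypothetical module: symbolic_trace records its gelu call as a
    # call_function node in the FX graph.
    def forward(self, x):
        return F.gelu(x)


gm = torch.fx.symbolic_trace(ToyModel())

# Walk the graph the same way replace_in_graph does and retarget any
# matching call_function node at the public functional API.
for node in gm.graph.nodes:
    if node.op == "call_function" and "gelu" in str(node.target):
        node.target = F.gelu
gm.recompile()  # regenerate gm.forward from the edited graph

x = torch.randn(2, 3)
assert torch.allclose(gm(x), F.gelu(x))  # behavior is unchanged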

graph_net/torch/fx_graph_serialize_util.py (1 addition, 1 deletion)

@@ -147,7 +147,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
         (r"torch\._C\.set_grad_enabled\(", "torch.set_grad_enabled("),
         # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
         (r"torch\._C\._nn\.pad\(", "torch.nn.functional.pad("),
-        # replace this line with modification code for task 125 (torch._C._nn.gelu)
+        (r"torch\._C\._nn\.gelu\(", "torch.nn.functional.gelu("),
         # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
         (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
     ]
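And a minimal sketch (again not part of the commit) of how pattern/replacement pairs like these rewrite a serialized graph string, assuming they are applied with re.sub over the generated source; the sample source line is made up for illustration.

import re

# Pattern/replacement pairs in the same shape as the list above
# (a subset, for illustration).
REPLACEMENTS = [
    (r"torch\._C\._nn\.pad\(", "torch.nn.functional.pad("),
    (r"torch\._C\._nn\.gelu\(", "torch.nn.functional.gelu("),
    (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
]

# FX-generated code can reference private bindings like torch._C._nn.gelu.
src = "gelu = torch._C._nn.gelu(linear_1)"

for pattern, replacement in REPLACEMENTS:
    src = re.sub(pattern, replacement, src)

print(src)  # gelu = torch.nn.functional.gelu(linear_1)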
