File tree Expand file tree Collapse file tree 2 files changed +26
-2
lines changed
Expand file tree Collapse file tree 2 files changed +26
-2
lines changed Original file line number Diff line number Diff line change @@ -292,7 +292,31 @@ def replace_in_graph(graph_mod):
292292
293293 return gm
294294
295- # replace this line with modification code for task 125 (torch._C._nn.gelu)
295+ def _impl_unstable_to_stable_gelu (self , gm ):
296+ """
297+ Convert torch._C._nn.gelu to torch.nn.functional.gelu
298+ """
299+ import torch .nn .functional as F
300+
301+ def replace_in_graph (graph_mod ):
302+ for node in graph_mod .graph .nodes :
303+ if node .op == "call_function" :
304+ if "gelu" in str (node .target ) and "torch._C._nn" in str (
305+ node .target
306+ ):
307+ node .target = F .gelu
308+ graph_mod .recompile ()
309+
310+ modules = [gm ]
311+ modules += [
312+ m
313+ for _ , m in gm .named_modules ()
314+ if isinstance (m , torch .fx .GraphModule ) and m is not gm
315+ ]
316+ for m in modules :
317+ replace_in_graph (m )
318+
319+ return gm
296320
297321 # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
298322
Original file line number Diff line number Diff line change @@ -147,7 +147,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
147147 (r"torch\._C\.set_grad_enabled\(" , "torch.set_grad_enabled(" ),
148148 # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
149149 (r"torch\._C\._nn\.pad\(" , "torch.nn.functional.pad(" ),
150- # replace this line with modification code for task 125 ( torch._C._nn.gelu)
150+ ( r" torch\ ._C\ ._nn\ .gelu\(" , "torch.nn.functional.gelu(" ),
151151 # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
152152 (r"torch\._C\._nn\.linear\(" , "torch.nn.functional.linear(" ),
153153 ]
You can’t perform that action at this time.
0 commit comments