Commit 78704e5

Convert torch._C._nn.pad to torch.nn.functional.pad
1 parent c65f7fa commit 78704e5

2 files changed: +24 -2 lines changed


graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 23 additions & 1 deletion
@@ -159,7 +159,29 @@ def _impl_unstable_to_stable_special_logit(self, gm):
 
     # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
 
-    # replace this line with modification code for task 123 (torch._C._nn.pad)
+    def _impl_unstable_to_stable_pad(self, gm):
+        """
+        Convert torch._C._nn.pad to torch.nn.functional.pad
+        """
+        import torch.nn.functional as F
+
+        def replace_in_graph(graph_mod):
+            for node in graph_mod.graph.nodes:
+                if node.op == "call_function":
+                    if "pad" in str(node.target) and "torch._C._nn" in str(node.target):
+                        node.target = F.pad
+            graph_mod.recompile()
+
+        modules = [gm]
+        modules += [
+            m
+            for _, m in gm.named_modules()
+            if isinstance(m, torch.fx.GraphModule) and m is not gm
+        ]
+        for m in modules:
+            replace_in_graph(m)
+
+        return gm
 
     # replace this line with modification code for task 125 (torch._C._nn.gelu)
 
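The pass above relies on a generic torch.fx mechanism: walk a GraphModule's call_function nodes, swap the target callable, and recompile so the regenerated source refers to the new name. Below is a minimal sketch of that mechanism only, not part of the commit; it uses torch.relu and torch.nn.functional.relu as stand-in targets, since whether torch._C._nn.pad appears as a distinct node target depends on the PyTorch build and how the graph was captured.

import torch
import torch.fx as fx
import torch.nn.functional as F

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = fx.symbolic_trace(Tiny())
for node in gm.graph.nodes:
    # Rewrite the call_function target to the stable functional API.
    if node.op == "call_function" and node.target is torch.relu:
        node.target = F.relu
gm.recompile()  # regenerate gm.code so the new target name is emitted

print(gm.code)  # the generated forward now calls torch.nn.functional.relu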

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
         # replace this line with modification code for task 119 (torch._C._nn.one_hot)
         # replace this line with modification code for task 121 (torch._C._set_grad_enabled)
         # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
-        # replace this line with modification code for task 123 (torch._C._nn.pad)
+        (r"torch\._C\._nn\.pad\(", "torch.nn.functional.pad("),
         # replace this line with modification code for task 125 (torch._C._nn.gelu)
         # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
         # replace this line with modification code for task 127 (torch._C._nn.linear)
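The serializer-side change adds a (pattern, replacement) pair that textually rewrites the unstable name in serialized FX graph code. For illustration only, here is how such a pair behaves on a single line of generated source; the sample line is a hypothetical stand-in, not taken from the repository.

import re

pattern, repl = (r"torch\._C\._nn\.pad\(", "torch.nn.functional.pad(")
src = "pad_1 = torch._C._nn.pad(arg0_1, [0, 1, 0, 1])"  # hypothetical serialized line
print(re.sub(pattern, repl, src))
# -> pad_1 = torch.nn.functional.pad(arg0_1, [0, 1, 0, 1])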
