Skip to content

Commit 7268252

Browse files
authored
Convert torch._C._set_grad_enabled and torch._C.set_grad_enabled to torch.set_grad_enabled (#341)
1 parent 9498e30 commit 7268252

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,28 @@ def _impl_unstable_to_stable_softplus(self, gm):
183183

184184
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
185185

186-
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)
186+
def _impl_unstable_to_stable_set_grad_enabled(self, gm):
187+
"""
188+
Convert torch._C._set_grad_enabled and torch._C.set_grad_enabled to torch.set_grad_enabled
189+
"""
190+
191+
def replace_in_graph(graph_mod):
192+
for node in graph_mod.graph.nodes:
193+
if node.op == "call_function":
194+
if "set_grad_enabled" in str(node.target):
195+
node.target = torch.set_grad_enabled
196+
graph_mod.recompile()
197+
198+
modules = [gm]
199+
modules += [
200+
m
201+
for _, m in gm.named_modules()
202+
if isinstance(m, torch.fx.GraphModule) and m is not gm
203+
]
204+
for m in modules:
205+
replace_in_graph(m)
206+
207+
return gm
187208

188209
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)
189210

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
143143
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
144144
(r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("),
145145
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
146-
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)
146+
(r"torch\._C\._set_grad_enabled\(", "torch.set_grad_enabled("),
147+
(r"torch\._C\.set_grad_enabled\(", "torch.set_grad_enabled("),
147148
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)
148149
# replace this line with modification code for task 123 (torch._C._nn.pad)
149150
# replace this line with modification code for task 125 (torch._C._nn.gelu)

0 commit comments

Comments
 (0)