Skip to content

Commit e38922f

Browse files
committed
Convert torch._C._set_grad_enabled and torch._C.set_grad_enabled to torch.set_grad_enabled
1 parent c65f7fa commit e38922f

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,28 @@ def _impl_unstable_to_stable_special_logit(self, gm):
155155

156156
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
157157

158-
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)
158+
def _impl_unstable_to_stable_set_grad_enabled(self, gm):
159+
"""
160+
Convert torch._C._set_grad_enabled and torch._C.set_grad_enabled to torch.set_grad_enabled
161+
"""
162+
163+
def replace_in_graph(graph_mod):
164+
for node in graph_mod.graph.nodes:
165+
if node.op == "call_function":
166+
if "set_grad_enabled" in str(node.target):
167+
node.target = torch.set_grad_enabled
168+
graph_mod.recompile()
169+
170+
modules = [gm]
171+
modules += [
172+
m
173+
for _, m in gm.named_modules()
174+
if isinstance(m, torch.fx.GraphModule) and m is not gm
175+
]
176+
for m in modules:
177+
replace_in_graph(m)
178+
179+
return gm
159180

160181
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)
161182

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2929
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
3030
# replace this line with modification code for task 118 (torch._C._nn.softplus)
3131
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
32-
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)
32+
(r"torch\._C\._set_grad_enabled\(", "torch.set_grad_enabled("),
33+
(r"torch\._C\.set_grad_enabled\(", "torch.set_grad_enabled("),
3334
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)
3435
# replace this line with modification code for task 123 (torch._C._nn.pad)
3536
# replace this line with modification code for task 125 (torch._C._nn.gelu)

0 commit comments

Comments
 (0)