Commit 6c5304a

chore(optim): wrap torch.autograd.grad() with torch.enable_grad() context
1 parent: b3f570c

File tree

torchopt/optim/func/base.py
torchopt/optim/meta/base.py

2 files changed: +14 −16 lines changed
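Background (my gloss, not part of the commit page): `torch.autograd.grad(..., create_graph=True)` only records the backward computation onto the autograd graph while grad mode is enabled. If a differentiable optimizer step is invoked under an outer `torch.no_grad()` block, the gradients come back silently detached and higher-order differentiation breaks. Wrapping the call in `torch.enable_grad()` makes the gradient computation independent of the caller's grad mode. A minimal standalone sketch of the behavior:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

loss = x**2  # graph built while grad mode is on
with torch.no_grad():
    # create_graph=True is ineffective here: grad mode is off, so the
    # backward pass is not recorded and `g` comes back detached.
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
print(g.requires_grad)  # False

loss = x**2
with torch.no_grad():
    with torch.enable_grad():  # what this commit adds around the grad call
        (g,) = torch.autograd.grad(loss, x, create_graph=True)
print(g.requires_grad)  # True: d(g)/dx is still reachable
```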

torchopt/optim/func/base.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -87,8 +87,10 @@ def step(
         if inplace is None:
             inplace = self.inplace
 
-        # Step parameter only
-        grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
+        with torch.enable_grad():
+            # Step parameters only
+            grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
+
         updates, self.optim_state = self.impl.update(
             grads,
             self.optim_state,
```
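For illustration, a hedged usage sketch (the `functorch.make_functional` pairing is my assumption about typical `FuncOptimizer` usage, not part of this commit): with the guard in place, `step()` computes a differentiable gradient regardless of the grad mode at the call site.

```python
import functorch
import torch
import torchopt

net = torch.nn.Linear(4, 1)
fmodel, params = functorch.make_functional(net)
optim = torchopt.FuncOptimizer(torchopt.adam(lr=1e-2))

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(fmodel(params, x), y)
# The torch.autograd.grad() call inside step() now runs under enable_grad().
params = optim.step(loss, params)
```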

torchopt/optim/meta/base.py

Lines changed: 10 additions & 14 deletions
```diff
@@ -66,32 +66,28 @@ def step(self, loss: torch.Tensor) -> None:  # pylint: disable=too-many-locals
             loss (torch.Tensor): The loss that is used to compute the gradients to the network
                 parameters.
         """
-        # Step parameter only
         for i, (param_container, state) in enumerate(
             zip(self.param_containers_groups, self.state_groups),
         ):
             flat_params: TupleOfTensors
             flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container)  # type: ignore[arg-type]
+
             if isinstance(state, UninitializedState):
                 state = self.impl.init(flat_params)
-            grads = torch.autograd.grad(
-                loss,
-                flat_params,
-                create_graph=True,
-                allow_unused=True,
-            )
-            updates, new_state = self.impl.update(
-                grads,
-                state,
-                params=flat_params,
-                inplace=False,
-            )
-            self.state_groups[i] = new_state
+
+            with torch.enable_grad():
+                # Step parameters only
+                grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
+
+            updates, new_state = self.impl.update(grads, state, params=flat_params, inplace=False)
+
             flat_new_params = apply_updates(flat_params, updates, inplace=False)
             new_params: ModuleTensorContainers = pytree.tree_unflatten(  # type: ignore[assignment]
                 container_treespec,
                 flat_new_params,
             )
+
+            self.state_groups[i] = new_state
             for container, new_param in zip(param_container, new_params):
                 container.update(new_param)
```
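And a corresponding sketch for the meta-optimizer path (`MetaSGD` and the toy module are my example, not from the diff): the inner step stays differentiable, so an outer loss can backpropagate through it.

```python
import torch
import torchopt

net = torch.nn.Linear(4, 1)
meta_optim = torchopt.MetaSGD(net, lr=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 1)

inner_loss = torch.nn.functional.mse_loss(net(x), y)
meta_optim.step(inner_loss)  # differentiable in-place update of net

outer_loss = torch.nn.functional.mse_loss(net(x), y)
print(outer_loss.requires_grad)  # True: the graph extends through the step
```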
