@@ -66,32 +66,28 @@ def step(self, loss: torch.Tensor) -> None:  # pylint: disable=too-many-locals
             loss (torch.Tensor): The loss that is used to compute the gradients to the network
                 parameters.
         """
-        # Step parameter only
         for i, (param_container, state) in enumerate(
             zip(self.param_containers_groups, self.state_groups),
         ):
             flat_params: TupleOfTensors
             flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container)  # type: ignore[arg-type]
+
             if isinstance(state, UninitializedState):
                 state = self.impl.init(flat_params)
-            grads = torch.autograd.grad(
-                loss,
-                flat_params,
-                create_graph=True,
-                allow_unused=True,
-            )
-            updates, new_state = self.impl.update(
-                grads,
-                state,
-                params=flat_params,
-                inplace=False,
-            )
-            self.state_groups[i] = new_state
+
+            with torch.enable_grad():
+                # Step parameters only
+                grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
+
+            updates, new_state = self.impl.update(grads, state, params=flat_params, inplace=False)
+
             flat_new_params = apply_updates(flat_params, updates, inplace=False)
             new_params: ModuleTensorContainers = pytree.tree_unflatten(  # type: ignore[assignment]
                 container_treespec,
                 flat_new_params,
             )
+
+            self.state_groups[i] = new_state
             for container, new_param in zip(param_container, new_params):
                 container.update(new_param)
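Presumably the point of the new `torch.enable_grad()` block: if `step()` is called from inside a `torch.no_grad()` context, `torch.autograd.grad(..., create_graph=True)` still returns gradients, but the derivative graph is not recorded, so the result comes back silently detached and higher-order differentiation breaks. A minimal standalone sketch of that behavior (the tensor `x` and the cubic loss are illustrative, not from this patch):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# The graph for `loss` is built normally; only the backward call runs under
# no_grad, mimicking a step() invoked from a no-grad context.
loss = x**3
with torch.no_grad():
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
print(g.requires_grad)  # False -- create_graph=True was silently ignored

loss = x**3
with torch.no_grad(), torch.enable_grad():  # the pattern this patch adopts
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
print(g.requires_grad)  # True -- the derivative graph was recorded
(gg,) = torch.autograd.grad(g, x)
print(gg)  # tensor(12.) -- d^2(x^3)/dx^2 = 6x at x = 2
```

Scoping only the `torch.autograd.grad` call to `enable_grad` keeps the patch minimal: the subsequent `impl.update` and `apply_updates` calls operate on gradients that already carry `grad_fn`, so they stay differentiable wherever the caller's grad mode permits.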
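If this hunk is indeed TorchOpt's `MetaOptimizer.step` (the identifiers `pytree.tree_flatten_as_tuple`, `UninitializedState`, and `apply_updates` suggest so, but that is an inference), the code path is exercised by the usual differentiable inner loop; the `MetaSGD` optimizer, toy module, and losses below are illustrative:

```python
import torch
import torchopt

net = torch.nn.Linear(4, 1)
meta_opt = torchopt.MetaSGD(net, lr=0.1)

x = torch.randn(8, 4)
inner_loss = net(x).pow(2).mean()
meta_opt.step(inner_loss)  # out-of-place update: new params keep their grad_fn

outer_loss = net(x).pow(2).mean()
outer_loss.backward()  # gradients flow back through the inner optimization step
```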