@@ -169,11 +169,20 @@ def compute_gramian(self, output: Tensor) -> Tensor:
169169 """
170170
171171 reshaped_output = output .reshape ([- 1 ])
172- return self ._compute_square_gramian (reshaped_output )
173172
174- def _compute_square_gramian (self , output : Tensor ) -> Tensor :
175173 self ._module_hook_manager .gramian_accumulation_phase = True
176174
175+ try :
176+ square_gramian = self ._compute_square_gramian (reshaped_output )
177+ finally :
178+ # Reset everything that has a state, even if the previous call raised an exception
179+ self ._module_hook_manager .gramian_accumulation_phase = False
180+ self ._gramian_accumulator .reset ()
181+ self ._target_edges .reset ()
182+
183+ return square_gramian
184+
185+ def _compute_square_gramian (self , output : Tensor ) -> Tensor :
177186 leaf_targets = list (self ._target_edges .get_leaf_edges ({get_gradient_edge (output )}))
178187
179188 def differentiation (_grad_output : Tensor ) -> tuple [Tensor , ...]:
@@ -190,9 +199,4 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
190199 # have failed. So gramian is necessarily a valid Tensor here.
191200 gramian = cast (Tensor , self ._gramian_accumulator .gramian )
192201
193- # Reset everything that has a state
194- self ._module_hook_manager .gramian_accumulation_phase = False
195- self ._gramian_accumulator .reset ()
196- self ._target_edges .reset ()
197-
198202 return gramian
0 commit comments