Skip to content

Commit 6431bbf

Browse files
authored
fix(autogram): Reset even when exception is raised (#417)
1 parent acea309 commit 6431bbf

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/torchjd/autogram/_engine.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -169,11 +169,20 @@ def compute_gramian(self, output: Tensor) -> Tensor:
         """

         reshaped_output = output.reshape([-1])
-        return self._compute_square_gramian(reshaped_output)

-    def _compute_square_gramian(self, output: Tensor) -> Tensor:
         self._module_hook_manager.gramian_accumulation_phase = True

+        try:
+            square_gramian = self._compute_square_gramian(reshaped_output)
+        finally:
+            # Reset everything that has a state, even if the previous call raised an exception
+            self._module_hook_manager.gramian_accumulation_phase = False
+            self._gramian_accumulator.reset()
+            self._target_edges.reset()
+
+        return square_gramian
+
+    def _compute_square_gramian(self, output: Tensor) -> Tensor:
         leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))

         def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
@@ -190,9 +199,4 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
         # have failed. So gramian is necessarily a valid Tensor here.
         gramian = cast(Tensor, self._gramian_accumulator.gramian)

-        # Reset everything that has a state
-        self._module_hook_manager.gramian_accumulation_phase = False
-        self._gramian_accumulator.reset()
-        self._target_edges.reset()
-
         return gramian

0 commit comments

Comments
 (0)