diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py
index 227878968..1262b1eab 100644
--- a/src/torchjd/autogram/_engine.py
+++ b/src/torchjd/autogram/_engine.py
@@ -6,7 +6,7 @@
 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
-from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
+from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithoutCrossTerms
 from ._gramian_utils import movedim_gramian, reshape_gramian
 from ._jacobian_computer import (
     AutogradJacobianComputer,
@@ -207,7 +207,7 @@ def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
             jacobian_computer = FunctionalJacobianComputer(module)
         else:
             jacobian_computer = AutogradJacobianComputer(module)
-        gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
+        gramian_computer = JacobianBasedGramianComputerWithoutCrossTerms(jacobian_computer)
         return gramian_computer
@@ -293,8 +293,6 @@ def compute_gramian(self, output: Tensor) -> Tensor:
         self._module_hook_manager.gramian_accumulation_phase.value = False
         self._gramian_accumulator.reset()
         self._target_edges.reset()
-        for gramian_computer in self._gramian_computers.values():
-            gramian_computer.reset()
 
         unordered_gramian = reshape_gramian(square_gramian, ordered_shape)
diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py
index 2bc62f218..e5932aa2a 100644
--- a/src/torchjd/autogram/_gramian_computer.py
+++ b/src/torchjd/autogram/_gramian_computer.py
@@ -1,5 +1,4 @@
 from abc import ABC, abstractmethod
-from typing import Optional
 
 from torch import Tensor
 from torch.utils._pytree import PyTree
@@ -15,18 +14,12 @@ def __call__(
         grad_outputs: tuple[Tensor, ...],
         args: tuple[PyTree, ...],
         kwargs: dict[str, PyTree],
-    ) -> Optional[Tensor]:
+    ) -> Tensor:
         """Compute what we can for a module and optionally return the gramian if it's ready."""
 
-    def track_forward_call(self) -> None:
-        """Track that the module's forward was called. Necessary in some implementations."""
-
-    def reset(self):
-        """Reset state if any. Necessary in some implementations."""
-
 
 class JacobianBasedGramianComputer(GramianComputer, ABC):
-    def __init__(self, jacobian_computer):
+    def __init__(self, jacobian_computer: JacobianComputer):
         self.jacobian_computer = jacobian_computer
 
     @staticmethod
@@ -34,45 +27,20 @@ def _to_gramian(jacobian: Tensor) -> Tensor:
         return jacobian @ jacobian.T
 
 
-class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
+class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer):
     """
-    Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning
-    the gramian.
+    JacobianBasedGramianComputer that directly returns the gramian without considering cross-terms
+    between modules (intra-module cross-terms are still included).
""" - def __init__(self, jacobian_computer: JacobianComputer): - super().__init__(jacobian_computer) - self.remaining_counter = 0 - self.summed_jacobian: Optional[Tensor] = None - - def reset(self) -> None: - self.remaining_counter = 0 - self.summed_jacobian = None - - def track_forward_call(self) -> None: - self.remaining_counter += 1 - def __call__( self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - ) -> Optional[Tensor]: + ) -> Tensor: """Compute what we can for a module and optionally return the gramian if it's ready.""" jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs) - - if self.summed_jacobian is None: - self.summed_jacobian = jacobian_matrix - else: - self.summed_jacobian += jacobian_matrix - - self.remaining_counter -= 1 - - if self.remaining_counter == 0: - gramian = self._to_gramian(self.summed_jacobian) - del self.summed_jacobian - return gramian - else: - return None + return self._to_gramian(jacobian_matrix) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 082ace69e..8fc39e716 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -123,8 +123,6 @@ def __call__( # require grad return outputs - self.gramian_computer.track_forward_call() - # We only care about running the AutogramNode, so we need one of its child # edges (the edges of the original outputs of the model) as target. For memory # efficiency, we select the smallest one (that requires grad). @@ -186,13 +184,12 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple: # For python > 3.10: -> tuple[None, None, None, None, None, *tuple[Tensor, ...]] if ctx.gramian_accumulation_phase: - optional_gramian = ctx.gramian_computer( + gramian = ctx.gramian_computer( ctx.rg_outputs, grad_outputs, ctx.args, ctx.kwargs, ) - if optional_gramian is not None: - ctx.gramian_accumulator.accumulate_gramian(optional_gramian) + ctx.gramian_accumulator.accumulate_gramian(gramian) return None, None, None, None, None, *grad_outputs diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 50135796e..84a13ea02 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -126,6 +126,8 @@ 32, marks=mark.filterwarnings("ignore:There is a performance drop"), ), + (ModuleFactory(ModelAlsoUsingSubmoduleParamsDirectly), 32), + (ModuleFactory(InterModuleParamReuse), 32), (ModuleFactory(FreeParam), 32), (ModuleFactory(NoFreeParam), 32), param(ModuleFactory(Cifar10Model), 16, marks=mark.slow), @@ -183,7 +185,7 @@ def _get_losses_and_params_without_cross_terms( return losses, params -_get_losses_and_params = _get_losses_and_params_with_cross_terms +_get_losses_and_params = _get_losses_and_params_without_cross_terms @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) @@ -222,8 +224,6 @@ def test_compute_gramian_with_weird_modules( "factory", [ ModuleFactory(ModelUsingSubmoduleParamsDirectly), - ModuleFactory(ModelAlsoUsingSubmoduleParamsDirectly), - ModuleFactory(InterModuleParamReuse), ], ) @mark.parametrize("batch_size", [1, 3, 32])