Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithoutCrossTerms
from ._gramian_utils import movedim_gramian, reshape_gramian
from ._jacobian_computer import (
AutogradJacobianComputer,
Expand Down Expand Up @@ -207,7 +207,7 @@ def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
jacobian_computer = FunctionalJacobianComputer(module)
else:
jacobian_computer = AutogradJacobianComputer(module)
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
gramian_computer = JacobianBasedGramianComputerWithoutCrossTerms(jacobian_computer)

return gramian_computer

Expand Down
35 changes: 5 additions & 30 deletions src/torchjd/autogram/_gramian_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,20 @@ def reset(self):


class JacobianBasedGramianComputer(GramianComputer, ABC):
def __init__(self, jacobian_computer):
def __init__(self, jacobian_computer: JacobianComputer):
self.jacobian_computer = jacobian_computer

@staticmethod
def _to_gramian(jacobian: Tensor) -> Tensor:
return jacobian @ jacobian.T


class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer):
"""
Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning
the gramian.
Stateful JacobianBasedGramianComputer that directly returning the gramian without considering
cross-terms (except intra-module cross-terms).
"""

def __init__(self, jacobian_computer: JacobianComputer):
super().__init__(jacobian_computer)
self.remaining_counter = 0
self.summed_jacobian: Optional[Tensor] = None

def reset(self) -> None:
self.remaining_counter = 0
self.summed_jacobian = None

def track_forward_call(self) -> None:
self.remaining_counter += 1

def __call__(
self,
rg_outputs: tuple[Tensor, ...],
Expand All @@ -62,17 +50,4 @@ def __call__(
"""Compute what we can for a module and optionally return the gramian if it's ready."""

jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)

if self.summed_jacobian is None:
self.summed_jacobian = jacobian_matrix
else:
self.summed_jacobian += jacobian_matrix

self.remaining_counter -= 1

if self.remaining_counter == 0:
gramian = self._to_gramian(self.summed_jacobian)
del self.summed_jacobian
return gramian
else:
return None
return self._to_gramian(jacobian_matrix)
6 changes: 3 additions & 3 deletions tests/unit/autogram/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@
32,
marks=mark.filterwarnings("ignore:There is a performance drop"),
),
(ModuleFactory(ModelAlsoUsingSubmoduleParamsDirectly), 32),
(ModuleFactory(InterModuleParamReuse), 32),
(ModuleFactory(FreeParam), 32),
(ModuleFactory(NoFreeParam), 32),
param(ModuleFactory(Cifar10Model), 16, marks=mark.slow),
Expand Down Expand Up @@ -183,7 +185,7 @@ def _get_losses_and_params_without_cross_terms(
return losses, params


_get_losses_and_params = _get_losses_and_params_with_cross_terms
_get_losses_and_params = _get_losses_and_params_without_cross_terms


@mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS)
Expand Down Expand Up @@ -222,8 +224,6 @@ def test_compute_gramian_with_weird_modules(
"factory",
[
ModuleFactory(ModelUsingSubmoduleParamsDirectly),
ModuleFactory(ModelAlsoUsingSubmoduleParamsDirectly),
ModuleFactory(InterModuleParamReuse),
],
)
@mark.parametrize("batch_size", [1, 3, 32])
Expand Down
Loading