6 changes: 2 additions & 4 deletions src/torchjd/autogram/_engine.py
@@ -6,7 +6,7 @@

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithoutCrossTerms
from ._gramian_utils import movedim_gramian, reshape_gramian
from ._jacobian_computer import (
AutogradJacobianComputer,
@@ -207,7 +207,7 @@ def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
jacobian_computer = FunctionalJacobianComputer(module)
else:
jacobian_computer = AutogradJacobianComputer(module)
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
gramian_computer = JacobianBasedGramianComputerWithoutCrossTerms(jacobian_computer)

return gramian_computer

@@ -293,8 +293,6 @@ def compute_gramian(self, output: Tensor) -> Tensor:
self._module_hook_manager.gramian_accumulation_phase.value = False
self._gramian_accumulator.reset()
self._target_edges.reset()
for gramian_computer in self._gramian_computers.values():
gramian_computer.reset()

unordered_gramian = reshape_gramian(square_gramian, ordered_shape)

46 changes: 7 additions & 39 deletions src/torchjd/autogram/_gramian_computer.py
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from typing import Optional

from torch import Tensor
from torch.utils._pytree import PyTree
@@ -15,64 +14,33 @@ def __call__(
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> Optional[Tensor]:
) -> Tensor:
"""Compute what we can for a module and optionally return the gramian if it's ready."""

def track_forward_call(self) -> None:
"""Track that the module's forward was called. Necessary in some implementations."""

def reset(self):
"""Reset state if any. Necessary in some implementations."""


class JacobianBasedGramianComputer(GramianComputer, ABC):
def __init__(self, jacobian_computer):
def __init__(self, jacobian_computer: JacobianComputer):
self.jacobian_computer = jacobian_computer

@staticmethod
def _to_gramian(jacobian: Tensor) -> Tensor:
return jacobian @ jacobian.T


class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer):
"""
Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning
the gramian.
JacobianBasedGramianComputer that directly returns the gramian without considering
cross-terms (except intra-module cross-terms).
"""

def __init__(self, jacobian_computer: JacobianComputer):
super().__init__(jacobian_computer)
self.remaining_counter = 0
self.summed_jacobian: Optional[Tensor] = None

def reset(self) -> None:
self.remaining_counter = 0
self.summed_jacobian = None

def track_forward_call(self) -> None:
self.remaining_counter += 1

def __call__(
self,
rg_outputs: tuple[Tensor, ...],
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> Optional[Tensor]:
) -> Tensor:
"""Compute what we can for a module and optionally return the gramian if it's ready."""

jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)

if self.summed_jacobian is None:
self.summed_jacobian = jacobian_matrix
else:
self.summed_jacobian += jacobian_matrix

self.remaining_counter -= 1

if self.remaining_counter == 0:
gramian = self._to_gramian(self.summed_jacobian)
del self.summed_jacobian
return gramian
else:
return None
return self._to_gramian(jacobian_matrix)
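
The difference between the two computers is easiest to see on a module whose forward is called more than once (parameter reuse). A minimal sketch, not code from this PR, assuming per-usage Jacobian matrices of shape (num_losses, num_params):

import torch

# Per-usage Jacobians of the losses w.r.t. one module's parameters,
# e.g. when the same module is applied twice during the forward pass.
j1 = torch.randn(4, 10)
j2 = torch.randn(4, 10)

# Previous behaviour (WithCrossTerms): sum the Jacobians over usages, then square.
# The result contains the inter-usage cross-terms j1 @ j2.T and j2 @ j1.T.
gramian_with_cross_terms = (j1 + j2) @ (j1 + j2).T

# New behaviour (WithoutCrossTerms): square each usage's Jacobian on its own and
# let the accumulator sum the per-usage gramians; inter-usage cross-terms are
# dropped, while cross-terms within a single call are still included.
gramian_without_cross_terms = j1 @ j1.T + j2 @ j2.T
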
7 changes: 2 additions & 5 deletions src/torchjd/autogram/_module_hook_manager.py
@@ -123,8 +123,6 @@ def __call__(
# require grad
return outputs

self.gramian_computer.track_forward_call()

# We only care about running the AutogramNode, so we need one of its child
# edges (the edges of the original outputs of the model) as target. For memory
# efficiency, we select the smallest one (that requires grad).
@@ -186,13 +184,12 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple:
# For python > 3.10: -> tuple[None, None, None, None, None, *tuple[Tensor, ...]]

if ctx.gramian_accumulation_phase:
optional_gramian = ctx.gramian_computer(
gramian = ctx.gramian_computer(
ctx.rg_outputs,
grad_outputs,
ctx.args,
ctx.kwargs,
)
if optional_gramian is not None:
ctx.gramian_accumulator.accumulate_gramian(optional_gramian)
ctx.gramian_accumulator.accumulate_gramian(gramian)

return None, None, None, None, None, *grad_outputs
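
Because the computer now returns a Tensor on every backward invocation, the hook can accumulate unconditionally: each usage of a module contributes its own gramian, with no counter to wait for. A minimal sketch of the resulting accumulation, with made-up shapes and a plain tensor standing in for the accumulator:

import torch

# One Jacobian per usage of the module; each backward invocation through the
# hook produces one gramian, which is summed immediately.
total_gramian = torch.zeros(4, 4)
for jacobian in (torch.randn(4, 10), torch.randn(4, 10)):
    total_gramian += jacobian @ jacobian.T
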
6 changes: 3 additions & 3 deletions tests/unit/autogram/test_engine.py
@@ -126,6 +126,8 @@
32,
marks=mark.filterwarnings("ignore:There is a performance drop"),
),
(ModuleFactory(ModelAlsoUsingSubmoduleParamsDirectly), 32),
(ModuleFactory(InterModuleParamReuse), 32),
(ModuleFactory(FreeParam), 32),
(ModuleFactory(NoFreeParam), 32),
param(ModuleFactory(Cifar10Model), 16, marks=mark.slow),
@@ -183,7 +185,7 @@ def _get_losses_and_params_without_cross_terms(
return losses, params


_get_losses_and_params = _get_losses_and_params_with_cross_terms
_get_losses_and_params = _get_losses_and_params_without_cross_terms


@mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS)
@@ -222,8 +224,6 @@ def test_compute_gramian_with_weird_modules(
"factory",
[
ModuleFactory(ModelUsingSubmoduleParamsDirectly),
ModuleFactory(ModelAlsoUsingSubmoduleParamsDirectly),
ModuleFactory(InterModuleParamReuse),
],
)
@mark.parametrize("batch_size", [1, 3, 32])