diff --git a/pyproject.toml b/pyproject.toml
index cd0822b04..a83fa44d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ authors = [
 ]
 requires-python = ">=3.10"
 dependencies = [
-    "torch>=2.0.0",
+    "torch[opt-einsum]>=2.0.0",
     "quadprog>=0.1.9, != 0.1.10",  # Doesn't work before 0.1.9, 0.1.10 is yanked
     "numpy>=1.21.0",  # Does not work before 1.21
     "qpsolvers>=1.0.1",  # Does not work before 1.0.1
diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py
index 227878968..bf88c0a8b 100644
--- a/src/torchjd/autogram/_engine.py
+++ b/src/torchjd/autogram/_engine.py
@@ -6,7 +6,11 @@
 
 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
-from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
+from ._gramian_computer import (
+    GramianComputer,
+    JacobianBasedGramianComputerWithCrossTerms,
+    LinearBasedGramianComputer,
+)
 from ._gramian_utils import movedim_gramian, reshape_gramian
 from ._jacobian_computer import (
     AutogradJacobianComputer,
@@ -203,11 +207,16 @@ def _hook_module_recursively(self, module: nn.Module) -> None:
 
     def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
         jacobian_computer: JacobianComputer
+        gramian_computer: GramianComputer
         if self._batch_dim is not None:
-            jacobian_computer = FunctionalJacobianComputer(module)
+            if isinstance(module, nn.Linear):
+                gramian_computer = LinearBasedGramianComputer(module)
+            else:
+                jacobian_computer = FunctionalJacobianComputer(module)
+                gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
         else:
             jacobian_computer = AutogradJacobianComputer(module)
-        gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
+            gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
         return gramian_computer
 
@@ -321,7 +330,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
             non_batch_dim_len = output.shape[0]
             identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype)
             ones = torch.ones_like(output[0])
-            jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones)
+            jac_output = torch.einsum(identity_matrix, [0, 1], ones, [...], [0, 1, ...])
 
             _ = vmap(differentiation)(jac_output)
         else:
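Reviewer note on the torch.einsum change just above (commentary, not part of the patch): the interleaved sublist form with an Ellipsis sublist computes the same tensor as the removed string form, namely jac_output[i, j, ...] = identity_matrix[i, j] * ones[...]. The new torch[opt-einsum] extra in pyproject.toml also makes opt_einsum available, which torch uses to pick a good contraction order for the multi-operand einsums introduced below. A minimal standalone equivalence check, with made-up shapes:

import torch

identity_matrix = torch.eye(3)
ones = torch.ones(4, 5)

# Removed string form and added sublist form agree.
a = torch.einsum("ij, ... -> ij...", identity_matrix, ones)
b = torch.einsum(identity_matrix, [0, 1], ones, [...], [0, 1, ...])

assert a.shape == b.shape == (3, 3, 4, 5)
assert torch.equal(a, b)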
-> ij...", identity_matrix, ones) + jac_output = torch.einsum(identity_matrix, [0, 1], ones, [...], [0, 1, ...]) _ = vmap(differentiation)(jac_output) else: diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 2bc62f218..ea495c634 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -1,9 +1,12 @@ from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Optional -from torch import Tensor +import torch +from torch import Tensor, nn from torch.utils._pytree import PyTree +from torchjd.autogram._gramian_utils import reshape_gramian from torchjd.autogram._jacobian_computer import JacobianComputer @@ -76,3 +79,118 @@ def __call__( return gramian else: return None + + +class ModuleBasedGramianComputer(GramianComputer, ABC): + def __init__(self, module: nn.Module): + self.module = module + + def __call__( + self, + rg_outputs: tuple[Tensor, ...], + grad_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + ) -> Tensor: + gramian = ComputeGramian.apply( + self._compute_gramian, rg_outputs, grad_outputs, args, kwargs + ) + return gramian + + @abstractmethod + def _compute_gramian( + self, + rg_outputs: tuple[Tensor, ...], + jac_outputs1: tuple[Tensor, ...], + jac_outputs2: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + ) -> Tensor: + """ + If G is the Gramian of the Jacobian of the model's output w.r.t. the parameters, and J1, J2 + are the jac_outputs (Jacobian of losses w.r.t. outputs), then this should compute the matrix + J1 @ G @ J2.T + """ + + +class ComputeGramian(torch.autograd.Function): + @staticmethod + def forward( + compute_gramian_fn: Callable[ + [ + tuple[Tensor, ...], + tuple[Tensor, ...], + tuple[Tensor, ...], + tuple[PyTree, ...], + dict[str, PyTree], + ], + Tensor, + ], + rg_outputs: tuple[Tensor, ...], + grad_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + ) -> Tensor: + # There is no non-batched dimension + gramian = compute_gramian_fn(rg_outputs, grad_outputs, grad_outputs, args, kwargs) + return gramian + + @staticmethod + def vmap( + _, + in_dims: tuple[None, None, tuple[int, ...], None, None], + compute_gramian_fn: Callable, + rg_outputs: tuple[Tensor, ...], + jac_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + ) -> tuple[Tensor, None]: + # There is a non-batched dimension + generalized_gramian = torch.vmap( + torch.vmap( + compute_gramian_fn, + in_dims=(None, in_dims[2], None, None, None), + out_dims=0, + ), + in_dims=(None, None, in_dims[2], None, None), + out_dims=-1, + )(rg_outputs, jac_outputs, jac_outputs, args, kwargs) + shape = generalized_gramian.shape + gramian = reshape_gramian(generalized_gramian, [shape[0] * shape[1]]) + return gramian, None + + @staticmethod + def setup_context(*_) -> None: + pass + + +class LinearBasedGramianComputer(ModuleBasedGramianComputer): + def __init__(self, module: nn.Linear): + super().__init__(module) + + def _compute_gramian( + self, + _: tuple[Tensor, ...], + jac_outputs1: tuple[Tensor, ...], + jac_outputs2: tuple[Tensor, ...], + args: tuple[PyTree, ...], + __: dict[str, PyTree], + ) -> Tensor: + + X = args[0] + dY1 = jac_outputs1[0] + dY2 = jac_outputs2[0] + + # TODO: add support for ndim==4 or find solution that works for any ndim. 
+
+
+class LinearBasedGramianComputer(ModuleBasedGramianComputer):
+    def __init__(self, module: nn.Linear):
+        super().__init__(module)
+
+    def _compute_gramian(
+        self,
+        _: tuple[Tensor, ...],
+        jac_outputs1: tuple[Tensor, ...],
+        jac_outputs2: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        __: dict[str, PyTree],
+    ) -> Tensor:
+
+        X = args[0]
+        dY1 = jac_outputs1[0]
+        dY2 = jac_outputs2[0]
+
+        # TODO: add support for ndim == 4, or find a solution that works for any ndim.
+        if dY1.ndim == 2:
+            G = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
+            if self.module.bias is not None:
+                G += torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
+        elif dY1.ndim == 3:  # Typical in transformers
+            G = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
+            if self.module.bias is not None:
+                G += torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
+        else:
+            raise ValueError("Higher dimensions are not supported. Open an issue if needed.")
+
+        return G
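Why the ndim == 2 branch is correct (commentary, not part of the patch): for a Linear layer, the per-sample weight gradient is the outer product of dY[b] and X[b], so pairwise dot products of per-sample gradients factor into <dY[b], dY[b']> * <X[b], X[b']>, which is exactly the four-operand einsum above; a bias adds <dY[b], dY[b']>. A standalone sanity check against explicitly materialized per-sample gradients, with made-up shapes:

import torch

batch, d_in, d_out = 5, 3, 2
X = torch.randn(batch, d_in)
dY = torch.randn(batch, d_out)  # plays the role of jac_outputs1[0] == jac_outputs2[0]

# Weight term plus bias term, as in the ndim == 2 branch of the patch.
G_fast = torch.einsum(dY, [0, 2], X, [0, 3], X, [1, 3], dY, [1, 2], [0, 1])
G_fast += torch.einsum(dY, [0, 2], dY, [1, 2], [0, 1])

# Reference: materialize each per-sample gradient [outer(dY[b], X[b]), dY[b]]
# flattened, then take all pairwise dot products.
per_sample = torch.stack(
    [torch.cat([torch.outer(dY[b], X[b]).ravel(), dY[b]]) for b in range(batch)]
)
G_ref = per_sample @ per_sample.T

assert torch.allclose(G_fast, G_ref, atol=1e-5)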