2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ authors = [
]
requires-python = ">=3.10"
dependencies = [
"torch>=2.0.0",
"torch[opt-einsum]>=2.0.0",
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"numpy>=1.21.0", # Does not work before 1.21
"qpsolvers>=1.0.1", # Does not work before 1.0.1
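Note on the `torch[opt-einsum]` extra: it pulls in the opt_einsum package, which `torch.einsum` uses, when available, to choose an efficient contraction order for calls with three or more operands, such as the four-operand Gramian contractions added in `_gramian_computer.py` below. A minimal sketch of the relevant settings (illustrative, not part of the diff):

# Illustrative sketch: how the opt-einsum extra affects torch.einsum.
# With opt_einsum installed, torch optimizes the contraction order of
# einsum calls with three or more operands; without it, operands are
# contracted left to right, which can be much slower.
import torch

print(torch.backends.opt_einsum.is_available())  # True when opt_einsum is installed
torch.backends.opt_einsum.enabled = True         # default; False forces left-to-right contraction
torch.backends.opt_einsum.strategy = "auto"      # or "greedy" / "optimal"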
17 changes: 13 additions & 4 deletions src/torchjd/autogram/_engine.py
@@ -6,7 +6,11 @@

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
from ._gramian_computer import (
GramianComputer,
JacobianBasedGramianComputerWithCrossTerms,
LinearBasedGramianComputer,
)
from ._gramian_utils import movedim_gramian, reshape_gramian
from ._jacobian_computer import (
AutogradJacobianComputer,
@@ -203,11 +207,16 @@ def _hook_module_recursively(self, module: nn.Module) -> None:

def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
jacobian_computer: JacobianComputer
gramian_computer: GramianComputer
if self._batch_dim is not None:
jacobian_computer = FunctionalJacobianComputer(module)
if isinstance(module, nn.Linear):
gramian_computer = LinearBasedGramianComputer(module)
else:
jacobian_computer = FunctionalJacobianComputer(module)
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
else:
jacobian_computer = AutogradJacobianComputer(module)
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)

return gramian_computer

@@ -321,7 +330,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
non_batch_dim_len = output.shape[0]
identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype)
ones = torch.ones_like(output[0])
jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones)
jac_output = torch.einsum(identity_matrix, [0, 1], ones, [...], [0, 1, ...])

_ = vmap(differentiation)(jac_output)
else:
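The `jac_output` construction above switches from the subscript-string form of `torch.einsum` to the sublist form, in which each operand is followed by a list of dimension labels and `...` stands for all remaining dimensions. A minimal sketch of what the call produces, using a hypothetical output shape (illustrative, not part of the diff):

# Illustrative sketch of the sublist einsum form used above. The call builds
# the outer product of an identity matrix with a tensor of ones, i.e. a stack
# of one-hot jac_outputs of shape (n, n, *output.shape[1:]).
import torch

n = 3
output = torch.randn(n, 4, 5)  # hypothetical module output
identity_matrix = torch.eye(n, device=output.device, dtype=output.dtype)
ones = torch.ones_like(output[0])
jac_output = torch.einsum(identity_matrix, [0, 1], ones, [...], [0, 1, ...])
assert jac_output.shape == (n, n, 4, 5)
# Equivalent to the string form it replaces:
assert torch.equal(jac_output, torch.einsum("ij, ... -> ij...", identity_matrix, ones))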
120 changes: 119 additions & 1 deletion src/torchjd/autogram/_gramian_computer.py
@@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Optional

from torch import Tensor
import torch
from torch import Tensor, nn
from torch.utils._pytree import PyTree

from torchjd.autogram._gramian_utils import reshape_gramian
from torchjd.autogram._jacobian_computer import JacobianComputer


@@ -76,3 +79,118 @@ def __call__(
return gramian
else:
return None


class ModuleBasedGramianComputer(GramianComputer, ABC):
def __init__(self, module: nn.Module):
self.module = module

def __call__(
self,
rg_outputs: tuple[Tensor, ...],
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> Tensor:
gramian = ComputeGramian.apply(
self._compute_gramian, rg_outputs, grad_outputs, args, kwargs
)
return gramian

@abstractmethod
def _compute_gramian(
self,
rg_outputs: tuple[Tensor, ...],
jac_outputs1: tuple[Tensor, ...],
jac_outputs2: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> Tensor:
"""
        If G is the Gramian of the Jacobian of the module's outputs w.r.t. its parameters, and
        J1, J2 are the jac_outputs (Jacobians of the losses w.r.t. the outputs), then this should
        compute the matrix J1 @ G @ J2.T.
"""

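A minimal numeric check of the contract stated in this docstring (illustrative, not part of the diff; shapes are arbitrary): with J_out the Jacobian of the module outputs w.r.t. its flattened parameters and G = J_out @ J_out.T, the chain rule gives (J1 @ J_out) @ (J2 @ J_out).T == J1 @ G @ J2.T.

# Illustrative sketch: the cross-Gramian of the losses w.r.t. the parameters
# equals J1 @ G @ J2.T, where G is the Gramian of the output-to-parameter
# Jacobian, so G is all that needs to be propagated.
import torch

n_out, n_par, n_loss = 4, 6, 3
J_out = torch.randn(n_out, n_par)   # Jacobian of module outputs w.r.t. parameters
G = J_out @ J_out.T                 # its Gramian
J1 = torch.randn(n_loss, n_out)     # Jacobians of losses w.r.t. outputs
J2 = torch.randn(n_loss, n_out)
assert torch.allclose((J1 @ J_out) @ (J2 @ J_out).T, J1 @ G @ J2.T)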

class ComputeGramian(torch.autograd.Function):
@staticmethod
def forward(
compute_gramian_fn: Callable[
[
tuple[Tensor, ...],
tuple[Tensor, ...],
tuple[Tensor, ...],
tuple[PyTree, ...],
dict[str, PyTree],
],
Tensor,
],
rg_outputs: tuple[Tensor, ...],
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> Tensor:
# There is no non-batched dimension
gramian = compute_gramian_fn(rg_outputs, grad_outputs, grad_outputs, args, kwargs)
return gramian

@staticmethod
def vmap(
_,
in_dims: tuple[None, None, tuple[int, ...], None, None],
compute_gramian_fn: Callable,
rg_outputs: tuple[Tensor, ...],
jac_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> tuple[Tensor, None]:
# There is a non-batched dimension
generalized_gramian = torch.vmap(
torch.vmap(
compute_gramian_fn,
in_dims=(None, in_dims[2], None, None, None),
out_dims=0,
),
in_dims=(None, None, in_dims[2], None, None),
out_dims=-1,
)(rg_outputs, jac_outputs, jac_outputs, args, kwargs)
shape = generalized_gramian.shape
gramian = reshape_gramian(generalized_gramian, [shape[0] * shape[1]])
return gramian, None

@staticmethod
def setup_context(*_) -> None:
pass

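The `vmap` staticmethod above evaluates `compute_gramian_fn` on every pair of slices of the jac_outputs by nesting two `torch.vmap` calls: the inner one maps over the first jac_outputs argument (stacked along dim 0), the outer one over the second (stacked along the last dim). A minimal sketch of the same pattern on a plain bilinear function (illustrative, not part of the diff):

# Illustrative sketch of the double-vmap pattern used in ComputeGramian.vmap.
# Vmapping a pairwise scalar function over its first argument (out_dims=0) and
# its second argument (out_dims=-1) evaluates it on all pairs of rows,
# yielding a (k, k) generalized Gramian.
import torch

def pairwise(j1: torch.Tensor, j2: torch.Tensor) -> torch.Tensor:
    # Stand-in for compute_gramian_fn restricted to two jac_output slices.
    return (j1 * j2).sum()

k, m = 4, 7
jac = torch.randn(k, m)
gram = torch.vmap(
    torch.vmap(pairwise, in_dims=(0, None), out_dims=0),
    in_dims=(None, 0),
    out_dims=-1,
)(jac, jac)
assert torch.allclose(gram, jac @ jac.T)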

class LinearBasedGramianComputer(ModuleBasedGramianComputer):
def __init__(self, module: nn.Linear):
super().__init__(module)

def _compute_gramian(
self,
_: tuple[Tensor, ...],
jac_outputs1: tuple[Tensor, ...],
jac_outputs2: tuple[Tensor, ...],
args: tuple[PyTree, ...],
__: dict[str, PyTree],
) -> Tensor:

X = args[0]
dY1 = jac_outputs1[0]
dY2 = jac_outputs2[0]

# TODO: add support for ndim==4 or find solution that works for any ndim.
if dY1.ndim == 2:
G = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
if self.module.bias is not None:
G += torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
elif dY1.ndim == 3: # Typical in transformers
G = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
if self.module.bias is not None:
G += torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
else:
raise ValueError("Higher dimensions not supported. Open an issue if needed.")

return G
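The `ndim == 2` branch exploits the fact that, for a Linear layer, the per-sample weight gradient is the outer product of the output gradient with the input, so each Gramian entry factors into a product of two dot products and the per-sample weight Jacobians never need to be materialized. A minimal check against the naive Jacobian-based computation (illustrative, not part of the diff; shapes are arbitrary):

# Illustrative sketch of what the ndim == 2 einsum in LinearBasedGramianComputer
# computes: G[i, j] = (dY1[i] . dY2[j]) * (X[i] . X[j]).
import torch

b, d_in, d_out = 5, 3, 2
X = torch.randn(b, d_in)      # module inputs
dY1 = torch.randn(b, d_out)   # jac_outputs w.r.t. the module outputs
dY2 = torch.randn(b, d_out)

G = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])

# Naive reference: flatten each per-sample weight gradient and take dot products.
J1 = torch.stack([torch.outer(dY1[i], X[i]).flatten() for i in range(b)])
J2 = torch.stack([torch.outer(dY2[j], X[j]).flatten() for j in range(b)])
assert torch.allclose(G, J1 @ J2.T)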