15 changes: 12 additions & 3 deletions src/torchjd/autogram/_engine.py
@@ -6,7 +6,11 @@

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
from ._gramian_computer import (
GramianComputer,
JacobianBasedGramianComputerWithCrossTerms,
LinearBasedGramianComputer,
)
from ._gramian_utils import movedim_gramian, reshape_gramian
from ._jacobian_computer import (
AutogradJacobianComputer,
@@ -203,11 +207,16 @@ def _hook_module_recursively(self, module: nn.Module) -> None:

def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
jacobian_computer: JacobianComputer
gramian_computer: GramianComputer
if self._batch_dim is not None:
jacobian_computer = FunctionalJacobianComputer(module)
if isinstance(module, nn.Linear):
gramian_computer = LinearBasedGramianComputer(module)
else:
jacobian_computer = FunctionalJacobianComputer(module)
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
else:
jacobian_computer = AutogradJacobianComputer(module)
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)

return gramian_computer

74 changes: 73 additions & 1 deletion src/torchjd/autogram/_gramian_computer.py
@@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Optional

from torch import Tensor
import torch
from torch import Tensor, nn
from torch.utils._pytree import PyTree

from torchjd.autogram._gramian_utils import reshape_gramian
from torchjd.autogram._jacobian_computer import JacobianComputer


@@ -76,3 +79,72 @@ def __call__(
return gramian
else:
return None


class LinearBasedGramianComputer(GramianComputer):
def __init__(self, module: nn.Linear):
self.module = module

def __call__(
self,
_: tuple[Tensor, ...],
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
__: dict[str, PyTree],
) -> Optional[Tensor]:

X = args[0]
dY = grad_outputs[0]

gramian = ComputeLinearGramian.apply(self._compute_gramian, dY, X)
return gramian

def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
"""
X is a matrix of shape [k, n] and dY1, dY2 are matrices of shape [k, m].
Contributor:

There's actually no guarantee that X, dY1 and dY2 are matrices.

From the documentation of nn.Linear:

[screenshot of the Shape section: Input (*, H_in), Output (*, H_out), i.e. an arbitrary number of leading dimensions]

In particular, when there is no batch dim, I think the * dimension could be empty, and in transformers, the * dimension is (batch_size, seq_length), which is why transformers fail with this PR.
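
To illustrate the point (a quick sketch, not from the thread): nn.Linear only acts on the last dimension, so the tensors reaching the hooks can have any number of leading dimensions:

import torch
from torch import nn

linear = nn.Linear(8, 4)
x_vec = torch.randn(8)           # no batch dim: the * dimensions are empty
x_batch = torch.randn(32, 8)     # usual case: * = (batch_size,)
x_seq = torch.randn(32, 128, 8)  # transformer-style: * = (batch_size, seq_length)
print(linear(x_vec).shape, linear(x_batch).shape, linear(x_seq).shape)
# torch.Size([4]) torch.Size([32, 4]) torch.Size([32, 128, 4])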

@ValerianRey (Contributor), Oct 15, 2025:

I made it work on transformers with this:

if dY1.ndim == 1:
    G_b = torch.einsum("k,k->", dY1, dY2)
    G_W = torch.einsum("k,l,l,k->", dY1, X, X, dY2)
elif dY1.ndim == 2:
    G_b = torch.einsum("ak,ik->ai", dY1, dY2)
    G_W = torch.einsum("ak,al,il,ik->ai", dY1, X, X, dY2)
elif dY1.ndim == 3:  # Typical in transformers
    G_b = torch.einsum("abk,ijk->ai", dY1, dY2)
    G_W = torch.einsum("abk,abl,ijl,ijk->ai", dY1, X, X, dY2)
else:
    raise ValueError("Higher dimensions not supported. Open an issue if needed.")

Not elegant at all, but it seems to work. Maybe there's a clean way to write this that works for any number of dimensions without the ifs. Also, please review the equations: I basically did them by trial and error until the tests passed.
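
One possible dimension-agnostic variant (a sketch only, using the module's existing torch/Tensor imports and assuming dY1, dY2 and X each have a leading sample dimension and a trailing feature dimension, i.e. ndim >= 2; the 1-D case would still need its own branch): flatten everything in between into a single position dimension and reuse the ndim == 3 einsums.

def compute_gramian_flat(dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
    # Collapse all dims between the first (sample) and last (feature) dim; for
    # ndim == 2 inputs this inserts a position dim of size 1, so the ndim == 3
    # formulas below cover both cases.
    dY1_f = dY1.reshape(dY1.shape[0], -1, dY1.shape[-1])
    dY2_f = dY2.reshape(dY2.shape[0], -1, dY2.shape[-1])
    X_f = X.reshape(X.shape[0], -1, X.shape[-1])
    G_b = torch.einsum("abk,ijk->ai", dY1_f, dY2_f)
    G_W = torch.einsum("abk,abl,ijl,ijk->ai", dY1_f, X_f, X_f, dY2_f)
    return G_b + G_W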

@PierreQuinton (Contributor, author), Oct 16, 2025:

Well, it needs to be at least matrices (ndim >= 2), since we know it's a batched scenario. We could in principle also handle the non-batched scenario, but I'm not sure it would be faster than the classical Jacobian-based GramianComputer.

> I did them basically with trial and error until the tests passed.

I wish I could have done that ^^

Returns dY1 @ G @ dY2, where G is the Gramian of the Jacobian of the module output
w.r.t. the module params.
"""

G_b = torch.einsum("ik,jk->ij", dY1, dY2)
G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)
Contributor:

This can be replaced by:

G_W = oe.contract("ik,il,jl,jk->ij", dY1, X, X, dY2, optimize="optimal", backend="torch")

with import opt_einsum as oe, but it seems to have exactly the same runtime and memory usage.

Contributor:

Actually, whenever opt_einsum is installed, the optimized contraction is already used even without changing the line:

G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)

We could still add the explicit line just to make it obvious, maybe.
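
For context (an editorial addition, assuming PyTorch >= 1.13, where the torch.backends.opt_einsum flag exists): this is the global setting that decides whether torch.einsum delegates contraction-order optimization to opt_einsum, and it can be inspected or forced explicitly:

import torch

print(torch.backends.opt_einsum.is_available())  # True when opt_einsum is importable
print(torch.backends.opt_einsum.enabled)         # True by default
torch.backends.opt_einsum.strategy = "optimal"   # or "auto" / "greedy"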

@PierreQuinton (Contributor, author), Oct 16, 2025:

Whatever you prefer. I prefer not having to pass the two additional parameters; for me, what matters here is

  1. It is an einsum
  2. It is fast

But the second criterion is more of a "how" than a "what", so we don't really need to spell it out. For this reason I would lean slightly towards torch.einsum. The downside is that a user could set opt_einsum's global settings to a non-optimized strategy, thereby making it slow, but I guess that is the user's responsibility.


return G_b + G_W


class ComputeLinearGramian(torch.autograd.Function):
@staticmethod
def forward(
compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
dY: Tensor,
X: Tensor,
) -> Tensor:
# There is no non-batched dimension
gramian = compute_gramian_fn(dY, dY, X)
return gramian

@staticmethod
def vmap(
_,
in_dims: tuple[None, tuple[int, ...], None],
compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
dY: Tensor,
X: Tensor,
) -> tuple[Tensor, None]:
# There is a non-batched dimension
generalized_gramian = torch.vmap(
torch.vmap(
compute_gramian_fn,
in_dims=(in_dims[1], None, None),
out_dims=0,
),
in_dims=(None, in_dims[1], None),
out_dims=-1,
)(dY, dY, X)
shape = dY.shape
gramian = reshape_gramian(generalized_gramian, [shape[0] * shape[1]])
return gramian, None

@staticmethod
def setup_context(*_) -> None:
pass
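
For reference, a small self-contained sanity check (not part of the PR; shapes and names are illustrative) that the einsum-based Gramian computed in _compute_gramian matches the Gramian obtained from explicit per-sample parameter gradients of a Linear layer:

import torch

batch, n_in, n_out = 4, 3, 5
X = torch.randn(batch, n_in)
dY = torch.randn(batch, n_out)

# Einsum-based Gramian, as in _compute_gramian with dY1 = dY2 = dY.
G_b = torch.einsum("ik,jk->ij", dY, dY)
G_W = torch.einsum("ik,il,jl,jk->ij", dY, X, X, dY)
G_fast = G_b + G_W

# Reference: the per-sample gradients of a Linear layer are dW_i = outer(dY_i, X_i)
# and db_i = dY_i; stacking their flattened versions gives a [batch, n_out*n_in + n_out]
# Jacobian J, whose Gramian is J @ J.T.
dW = torch.einsum("ik,il->ikl", dY, X).reshape(batch, -1)
J = torch.cat([dW, dY], dim=1)
G_ref = J @ J.T

print(torch.allclose(G_fast, G_ref, atol=1e-5))  # expected: True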