Skip to content

Commit 5221159

Browse files
committed
Add LinearBasedGramianComputer.
1 parent 628e0b0 commit 5221159

File tree

2 files changed

+85
-4
lines changed

2 files changed

+85
-4
lines changed

src/torchjd/autogram/_engine.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
from ._edge_registry import EdgeRegistry
88
from ._gramian_accumulator import GramianAccumulator
9-
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
9+
from ._gramian_computer import (
10+
GramianComputer,
11+
JacobianBasedGramianComputerWithCrossTerms,
12+
LinearBasedGramianComputer,
13+
)
1014
from ._gramian_utils import movedim_gramian, reshape_gramian
1115
from ._jacobian_computer import (
1216
AutogradJacobianComputer,
@@ -203,11 +207,16 @@ def _hook_module_recursively(self, module: nn.Module) -> None:
203207

204208
def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
205209
jacobian_computer: JacobianComputer
210+
gramian_computer: GramianComputer
206211
if self._batch_dim is not None:
207-
jacobian_computer = FunctionalJacobianComputer(module)
212+
if isinstance(module, nn.Linear):
213+
gramian_computer = LinearBasedGramianComputer(module)
214+
else:
215+
jacobian_computer = FunctionalJacobianComputer(module)
216+
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
208217
else:
209218
jacobian_computer = AutogradJacobianComputer(module)
210-
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
219+
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
211220

212221
return gramian_computer
213222

src/torchjd/autogram/_gramian_computer.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from abc import ABC, abstractmethod
2+
from collections.abc import Callable
23
from typing import Optional
34

4-
from torch import Tensor
5+
import torch
6+
from torch import Tensor, nn
57
from torch.utils._pytree import PyTree
68

9+
from torchjd.autogram._gramian_utils import reshape_gramian
710
from torchjd.autogram._jacobian_computer import JacobianComputer
811

912

@@ -76,3 +79,72 @@ def __call__(
7679
return gramian
7780
else:
7881
return None
82+
83+
84+
class LinearBasedGramianComputer(GramianComputer):
85+
def __init__(self, module: nn.Linear):
86+
self.module = module
87+
88+
def __call__(
89+
self,
90+
_: tuple[Tensor, ...],
91+
grad_outputs: tuple[Tensor, ...],
92+
args: tuple[PyTree, ...],
93+
__: dict[str, PyTree],
94+
) -> Optional[Tensor]:
95+
96+
X = args[0]
97+
dY = grad_outputs[0]
98+
99+
gramian = ComputeLinearGramian.apply(self._compute_gramian, dY, X)
100+
return gramian
101+
102+
def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
103+
"""
104+
X is a matrix of shape [k, n] and dY1, dY2 are matrices of shape [k, m].
105+
Returns the dY1 @ G @ dY2 where G is the Gramian of the Jacobian of the module output w.r.t.
106+
to the module params.
107+
"""
108+
109+
G_b = torch.einsum("ik,jk->ij", dY1, dY2)
110+
G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)
111+
112+
return G_b + G_W
113+
114+
115+
class ComputeLinearGramian(torch.autograd.Function):
116+
@staticmethod
117+
def forward(
118+
compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
119+
dY: Tensor,
120+
X: Tensor,
121+
) -> Tensor:
122+
# There is no non-batched dimension
123+
gramian = compute_gramian_fn(dY, dY, X)
124+
return gramian
125+
126+
@staticmethod
127+
def vmap(
128+
_,
129+
in_dims: tuple[None, tuple[int, ...], None],
130+
compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
131+
dY: Tensor,
132+
X: Tensor,
133+
) -> tuple[Tensor, None]:
134+
# There is a non-batched dimension
135+
generalized_gramian = torch.vmap(
136+
torch.vmap(
137+
compute_gramian_fn,
138+
in_dims=(in_dims[1], None, None),
139+
out_dims=0,
140+
),
141+
in_dims=(None, in_dims[1], None),
142+
out_dims=-1,
143+
)(dY, dY, X)
144+
shape = dY.shape
145+
gramian = reshape_gramian(generalized_gramian, [shape[0] * shape[1]])
146+
return gramian, None
147+
148+
@staticmethod
149+
def setup_context(*_) -> None:
150+
pass

0 commit comments

Comments
 (0)