Skip to content

Commit 0101465

Browse files
committed
Add and use JacobianBasedGramianComputerWithoutCrossTerms
1 parent e567e5d commit 0101465

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

src/torchjd/autogram/_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ._gramian_accumulator import GramianAccumulator
99
from ._gramian_computer import (
1010
GramianComputer,
11-
JacobianBasedGramianComputerWithCrossTerms,
11+
JacobianBasedGramianComputerWithoutCrossTerms,
1212
LinearBasedGramianComputer,
1313
)
1414
from ._gramian_utils import movedim_gramian, reshape_gramian
@@ -213,10 +213,10 @@ def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
213213
gramian_computer = LinearBasedGramianComputer(module)
214214
else:
215215
jacobian_computer = FunctionalJacobianComputer(module)
216-
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
216+
gramian_computer = JacobianBasedGramianComputerWithoutCrossTerms(jacobian_computer)
217217
else:
218218
jacobian_computer = AutogradJacobianComputer(module)
219-
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
219+
gramian_computer = JacobianBasedGramianComputerWithoutCrossTerms(jacobian_computer)
220220

221221
return gramian_computer
222222

src/torchjd/autogram/_gramian_computer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,25 @@ def __call__(
8181
return None
8282

8383

84+
class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer):
85+
"""
86+
Stateful JacobianBasedGramianComputer that directly returning the gramian without considering
87+
cross-terms (except intra-module cross-terms).
88+
"""
89+
90+
def __call__(
91+
self,
92+
rg_outputs: tuple[Tensor, ...],
93+
grad_outputs: tuple[Tensor, ...],
94+
args: tuple[PyTree, ...],
95+
kwargs: dict[str, PyTree],
96+
) -> Optional[Tensor]:
97+
"""Compute what we can for a module and optionally return the gramian if it's ready."""
98+
99+
jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)
100+
return self._to_gramian(jacobian_matrix)
101+
102+
84103
class ModuleBasedGramianComputer(GramianComputer, ABC):
85104
def __init__(self, module: nn.Module):
86105
self.module = module

0 commit comments

Comments
 (0)