Skip to content

Commit ac384a0

Browse files
committed
Reorder functions
1 parent 9884888 commit ac384a0

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

src/torchjd/autogram/_gramian_computer.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -113,38 +113,6 @@ def _compute_gramian(
113113
"""
114114

115115

116-
class LinearBasedGramianComputer(ModuleBasedGramianComputer):
    """Gramian computer specialized for ``nn.Linear`` modules.

    Computes the Gramian of the module's parameter Jacobian directly from the
    layer input and the two output-Jacobian factors via ``einsum``, without
    materializing per-parameter Jacobians.
    """

    def __init__(self, module: nn.Linear):
        super().__init__(module)

    def _compute_gramian(
        self,
        _: tuple[Tensor, ...],
        jac_outputs1: tuple[Tensor, ...],
        jac_outputs2: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        __: dict[str, PyTree],
    ) -> Tensor:
        # Only the first positional arg (the layer input) and the first
        # element of each jac_outputs tuple are used; the remaining
        # parameters are ignored (hence the `_` / `__` names).
        X = args[0]
        dY1 = jac_outputs1[0]
        dY2 = jac_outputs2[0]

        # TODO: add support for ndim==4 or find solution that works for any ndim.
        if dY1.ndim == 2:
            # Weight contribution. Index 0/1 pair the two Jacobian factors;
            # assumes dY* is (jac_dim, out_features) and X is
            # (jac_dim, in_features) — TODO confirm against caller.
            G = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
            if self.module.bias is not None:
                # Bias contribution: same contraction without the input factors.
                G += torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
        elif dY1.ndim == 3:  # Typical in transformers
            # Same computation with an extra (sequence-like) axis contracted
            # alongside the batch-like axis.
            G = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
            if self.module.bias is not None:
                G += torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
        else:
            raise ValueError("Higher dimensions not supported. Open an issue if needed.")

        return G
147-
148116
class ComputeGramian(torch.autograd.Function):
149117
@staticmethod
150118
def forward(
@@ -194,3 +162,35 @@ def vmap(
194162
@staticmethod
195163
def setup_context(*_) -> None:
196164
pass
165+
166+
167+
class LinearBasedGramianComputer(ModuleBasedGramianComputer):
    """Gramian computer specialized for ``nn.Linear`` modules.

    Contracts the layer input with the two output-Jacobian factors through
    ``einsum`` to obtain the parameter-Jacobian Gramian directly.
    """

    def __init__(self, module: nn.Linear):
        super().__init__(module)

    def _compute_gramian(
        self,
        _: tuple[Tensor, ...],
        jac_outputs1: tuple[Tensor, ...],
        jac_outputs2: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        __: dict[str, PyTree],
    ) -> Tensor:
        # Only the layer input and the first element of each Jacobian tuple
        # are needed; the other parameters are deliberately unused.
        inputs = args[0]
        jac1 = jac_outputs1[0]
        jac2 = jac_outputs2[0]
        has_bias = self.module.bias is not None

        # TODO: add support for ndim==4 or find solution that works for any ndim.
        if jac1.ndim == 2:
            # Weight term: contract inputs and Jacobians over their shared axes.
            gramian = torch.einsum(
                jac1, [0, 2], inputs, [0, 3], inputs, [1, 3], jac2, [1, 2], [0, 1]
            )
            if has_bias:
                # Bias term: same contraction, without the input factors.
                gramian = gramian + torch.einsum(jac1, [0, 2], jac2, [1, 2], [0, 1])
        elif jac1.ndim == 3:  # Typical in transformers
            # Same contraction with one extra axis handled alongside the first.
            gramian = torch.einsum(
                jac1, [0, 2, 4], inputs, [0, 2, 5], inputs, [1, 3, 5], jac2, [1, 3, 4], [0, 1]
            )
            if has_bias:
                gramian = gramian + torch.einsum(jac1, [0, 2, 4], jac2, [1, 3, 4], [0, 1])
        else:
            raise ValueError("Higher dimensions not supported. Open an issue if needed.")

        return gramian

0 commit comments

Comments
 (0)