Skip to content

Commit 9884888

Browse files
committed
Add ModuleBasedGramianComputer
1 parent 956e6ce commit 9884888

File tree

1 file changed

+63
-26
lines changed

src/torchjd/autogram/_gramian_computer.py

Lines changed: 63 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -81,31 +81,55 @@ def __call__(
8181
return None
8282

8383

84-
class LinearBasedGramianComputer(GramianComputer):
85-
def __init__(self, module: nn.Linear):
84+
class ModuleBasedGramianComputer(GramianComputer, ABC):
85+
def __init__(self, module: nn.Module):
8686
self.module = module
8787

8888
def __call__(
8989
self,
90-
_: tuple[Tensor, ...],
90+
rg_outputs: tuple[Tensor, ...],
9191
grad_outputs: tuple[Tensor, ...],
9292
args: tuple[PyTree, ...],
93-
__: dict[str, PyTree],
94-
) -> Optional[Tensor]:
95-
96-
X = args[0]
97-
dY = grad_outputs[0]
98-
99-
gramian = ComputeGramian.apply(self._compute_gramian, dY, X)
93+
kwargs: dict[str, PyTree],
94+
) -> Tensor:
95+
gramian = ComputeGramian.apply(
96+
self._compute_gramian, rg_outputs, grad_outputs, args, kwargs
97+
)
10098
return gramian
10199

102-
def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
100+
@abstractmethod
101+
def _compute_gramian(
102+
self,
103+
rg_outputs: tuple[Tensor, ...],
104+
jac_outputs1: tuple[Tensor, ...],
105+
jac_outputs2: tuple[Tensor, ...],
106+
args: tuple[PyTree, ...],
107+
kwargs: dict[str, PyTree],
108+
) -> Tensor:
103109
"""
104-
X is a matrix of shape [k, n] and dY1, dY2 are matrices of shape [k, m].
105-
Returns the dY1 @ G @ dY2 where G is the Gramian of the Jacobian of the module output w.r.t.
106-
to the module params.
110+
If G is the Gramian of the Jacobian of the model's output w.r.t. the parameters, and J1, J2
111+
are the jac_outputs (Jacobian of losses w.r.t. outputs), then this should compute the matrix
112+
J1 @ G @ J2.T
107113
"""
108114

115+
116+
class LinearBasedGramianComputer(ModuleBasedGramianComputer):
117+
def __init__(self, module: nn.Linear):
118+
super().__init__(module)
119+
120+
def _compute_gramian(
121+
self,
122+
_: tuple[Tensor, ...],
123+
jac_outputs1: tuple[Tensor, ...],
124+
jac_outputs2: tuple[Tensor, ...],
125+
args: tuple[PyTree, ...],
126+
__: dict[str, PyTree],
127+
) -> Tensor:
128+
129+
X = args[0]
130+
dY1 = jac_outputs1[0]
131+
dY2 = jac_outputs2[0]
132+
109133
# TODO: add support for ndim==4 or find solution that works for any ndim.
110134
if dY1.ndim == 2:
111135
G = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
@@ -124,33 +148,46 @@ def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
124148
class ComputeGramian(torch.autograd.Function):
125149
@staticmethod
126150
def forward(
127-
compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
128-
dY: Tensor,
129-
X: Tensor,
151+
compute_gramian_fn: Callable[
152+
[
153+
tuple[Tensor, ...],
154+
tuple[Tensor, ...],
155+
tuple[Tensor, ...],
156+
tuple[PyTree, ...],
157+
dict[str, PyTree],
158+
],
159+
Tensor,
160+
],
161+
rg_outputs: tuple[Tensor, ...],
162+
grad_outputs: tuple[Tensor, ...],
163+
args: tuple[PyTree, ...],
164+
kwargs: dict[str, PyTree],
130165
) -> Tensor:
131166
# There is no non-batched dimension
132-
gramian = compute_gramian_fn(dY, dY, X)
167+
gramian = compute_gramian_fn(rg_outputs, grad_outputs, grad_outputs, args, kwargs)
133168
return gramian
134169

135170
@staticmethod
136171
def vmap(
137172
_,
138-
in_dims: tuple[None, tuple[int, ...], None],
139-
compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
140-
dY: Tensor,
141-
X: Tensor,
173+
in_dims: tuple[None, None, tuple[int, ...], None, None],
174+
compute_gramian_fn: Callable,
175+
rg_outputs: tuple[Tensor, ...],
176+
jac_outputs: tuple[Tensor, ...],
177+
args: tuple[PyTree, ...],
178+
kwargs: dict[str, PyTree],
142179
) -> tuple[Tensor, None]:
143180
# There is a non-batched dimension
144181
generalized_gramian = torch.vmap(
145182
torch.vmap(
146183
compute_gramian_fn,
147-
in_dims=(in_dims[1], None, None),
184+
in_dims=(None, in_dims[2], None, None, None),
148185
out_dims=0,
149186
),
150-
in_dims=(None, in_dims[1], None),
187+
in_dims=(None, None, in_dims[2], None, None),
151188
out_dims=-1,
152-
)(dY, dY, X)
153-
shape = dY.shape
189+
)(rg_outputs, jac_outputs, jac_outputs, args, kwargs)
190+
shape = generalized_gramian.shape
154191
gramian = reshape_gramian(generalized_gramian, [shape[0] * shape[1]])
155192
return gramian, None
156193

0 commit comments

Comments
 (0)