feat(autogram): Add ModuleBasedGramianComputer.
#458
```diff
@@ -1,9 +1,12 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from typing import Optional
 
-from torch import Tensor
+import torch
+from torch import Tensor, nn
+from torch.utils._pytree import PyTree
 
+from torchjd.autogram._gramian_utils import reshape_gramian
 from torchjd.autogram._jacobian_computer import JacobianComputer
@@ -76,3 +79,72 @@ def __call__(
             return gramian
         else:
             return None
+
+
+class LinearBasedGramianComputer(GramianComputer):
+    def __init__(self, module: nn.Linear):
+        self.module = module
+
+    def __call__(
+        self,
+        _: tuple[Tensor, ...],
+        grad_outputs: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        __: dict[str, PyTree],
+    ) -> Optional[Tensor]:
+        X = args[0]
+        dY = grad_outputs[0]
+
+        gramian = ComputeLinearGramian.apply(self._compute_gramian, dY, X)
+        return gramian
+
+    def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
+        """
+        X is a matrix of shape [k, n] and dY1, dY2 are matrices of shape [k, m].
+        Returns dY1 @ G @ dY2 where G is the Gramian of the Jacobian of the module output
+        w.r.t. the module params.
+        """
+        G_b = torch.einsum("ik,jk->ij", dY1, dY2)
+        G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)
+
+        return G_b + G_W
+
+
+class ComputeLinearGramian(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
+        dY: Tensor,
+        X: Tensor,
+    ) -> Tensor:
+        # There is no non-batched dimension
+        gramian = compute_gramian_fn(dY, dY, X)
+        return gramian
+
+    @staticmethod
+    def vmap(
+        _,
+        in_dims: tuple[None, tuple[int, ...], None],
+        compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
+        dY: Tensor,
+        X: Tensor,
+    ) -> tuple[Tensor, None]:
+        # There is a non-batched dimension
+        generalized_gramian = torch.vmap(
+            torch.vmap(
+                compute_gramian_fn,
+                in_dims=(in_dims[1], None, None),
+                out_dims=0,
+            ),
+            in_dims=(None, in_dims[1], None),
+            out_dims=-1,
+        )(dY, dY, X)
+        shape = dY.shape
+        gramian = reshape_gramian(generalized_gramian, [shape[0] * shape[1]])
+        return gramian, None
+
+    @staticmethod
+    def setup_context(*_) -> None:
+        pass
```
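As a sanity check on the two einsum expressions above, here is a small standalone script (not part of the PR; all names and sizes below are made up) that compares them against parameter gradients computed by autograd for a plain 2-D batch:

```python
import torch
from torch import nn

# Reference: per-sample parameter gradients of an nn.Linear, obtained with autograd.
# Entry (i, j) of the Gramian should be the inner product of the gradients produced
# by back-propagating dY[i] through sample i and dY[j] through sample j.
torch.manual_seed(0)
k, n, m = 5, 3, 4
linear = nn.Linear(n, m)
X = torch.randn(k, n)
dY = torch.randn(k, m)

rows = []
for i in range(k):
    linear.zero_grad()
    linear(X[i]).backward(dY[i])
    rows.append(torch.cat([linear.weight.grad.flatten(), linear.bias.grad]))
J = torch.stack(rows)   # [k, num_params]
reference = J @ J.T     # pairwise inner products, shape [k, k]

# Closed-form Gramian from the einsum expressions in the diff.
G_b = torch.einsum("ik,jk->ij", dY, dY)
G_W = torch.einsum("ik,il,jl,jk->ij", dY, X, X, dY)

print(torch.allclose(G_b + G_W, reference, atol=1e-5))  # expected: True
```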
There's actually no guarantee that X, dY1 and dY2 are matrices. From the documentation of nn.Linear:

> Input: `(*, H_in)` where `*` means any number of dimensions, including none, and `H_in = in_features`.
> Output: `(*, H_out)` where all but the last dimension are the same shape as the input and `H_out = out_features`.

In particular, when there is no batch dim, I think the `*` dimension could be empty, and in transformers, the `*` dimension is (batch_size, seq_length), which is why transformers fail with this PR.
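For illustration (a made-up example, not from the PR): the same nn.Linear accepts 1-D, 2-D, or 3-D inputs, so the X and dY seen by the hook are not guaranteed to be matrices.

```python
import torch
from torch import nn

linear = nn.Linear(8, 4)

x_no_batch = torch.randn(8)             # * is empty
x_batched = torch.randn(32, 8)          # * is (batch_size,)
x_transformer = torch.randn(32, 10, 8)  # * is (batch_size, seq_length)

print(linear(x_no_batch).shape)     # torch.Size([4])
print(linear(x_batched).shape)      # torch.Size([32, 4])
print(linear(x_transformer).shape)  # torch.Size([32, 10, 4])
```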
I made it work on Transformers with that:

Not elegant at all, but it seems to work. Maybe there's a cleaner way to write this that works for any number of dimensions without the ifs. Also, please review the equations: I basically got them by trial and error until the tests passed.
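Since the snippet referenced above is not shown here, the following is only a hypothetical sketch of one if-free alternative: flatten every leading dimension into a single row dimension and reuse the 2-D einsum formulas. The function name is invented, and whether treating each leading index as its own row (rather than summing over the extra dimensions within a sample) is the semantics the engine needs would still have to be checked.

```python
import torch
from torch import Tensor

def compute_gramian_any_ndim(dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
    # X: [..., n]; dY1 and dY2: [..., m], all with the same leading dimensions.
    # Collapse every leading dimension into one "row" dimension, then apply the
    # same two einsum expressions as in the 2-D case.
    X_flat = X.reshape(-1, X.shape[-1])
    dY1_flat = dY1.reshape(-1, dY1.shape[-1])
    dY2_flat = dY2.reshape(-1, dY2.shape[-1])

    G_b = torch.einsum("ik,jk->ij", dY1_flat, dY2_flat)
    G_W = torch.einsum("ik,il,jl,jk->ij", dY1_flat, X_flat, X_flat, dY2_flat)
    return G_b + G_W
```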
Well, it needs to be at least matrices (2 <= ndim), since we know it's a batched scenario. We could in principle add the no-batched-dimension scenario, but I'm not sure it would be faster than the classical Jacobian-based GramianComputer. I wish I could have done that ^^