Conversation

@PierreQuinton
Contributor

No description provided.

@PierreQuinton PierreQuinton added feat New feature or request package: autogram labels Oct 15, 2025
@ValerianRey
Contributor

ValerianRey commented Oct 15, 2025

This is big for architectures with very large linear layers, like AlexNet.
For AlexNet, on cuda, with batch_dim=0, this leads to:

  • Double the max batch size (from batch_size=19 to batch_size=38 on my GPU). Sadly, this is still very far from the max batch size of 1268 of SGD with autograd. Do you think there might be a theoretical way to bridge this gap even more? EDIT: I have no idea why, but re-running the same tests (or maybe I made a mistake before) yields completely different results. The max batch size is now 18 for main and 468 for this PR - much, much closer to the 1268 of autograd. EDIT2: the big memory improvement (and a small speed improvement) comes from just installing opt_einsum (without even changing the code): uv pip install opt_einsum.
  • A 2x to 4x speedup (depending on the batch size) of the whole autogram_forward_backward function (so this includes not only the gramian computation of the linear layers, but also that of all other layers, plus the forward passes).

For other architectures, the differences are not very noticeable, though. But this is very promising. Let's fully focus on this direction IMO.

"""

G_b = torch.einsum("ik,jk->ij", dY1, dY2)
G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)
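
For context, a hedged reading of what these two contractions compute, assuming dY1 and dY2 hold per-sample gradients with respect to the output of a Linear layer $y = Wx + b$ and X holds the corresponding inputs:

$$G_b[i,j] = \sum_k dY_1[i,k]\, dY_2[j,k], \qquad G_W[i,j] = \sum_{k,l} dY_1[i,k]\, X[i,l]\, dY_2[j,k]\, X[j,l],$$

i.e. $G_W[i,j]$ is the Frobenius inner product between the per-sample weight gradients $dY_1[i] \otimes X[i]$ and $dY_2[j] \otimes X[j]$, computed without ever materializing those outer products.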

This can be replaced by:

G_W = oe.contract("ik,il,jl,jk->ij", dY1, X, X, dY2, optimize="optimal", backend="torch")

with import opt_einsum as oe, but it seems to give exactly the same runtime and memory usage.


Actually, whenever opt_einsum is installed, the optimized contraction is already used even without changing the line:

G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)

We could still switch to the explicit opt_einsum call, maybe just to make the dependency explicit.

Contributor Author

@PierreQuinton PierreQuinton Oct 16, 2025


Whatever you prefer. I prefer not having to pass the two additional parameters; for me, what is important here is

  1. It is an einsum
  2. It is fast

But the second criterion is more of a "how" than a "what", so we don't really need to state it. For this reason I would lean slightly towards torch.einsum. The downside is that a user could set opt_einsum's global settings to a non-optimized strategy, thereby making it slow, but I guess that is the user's responsibility.
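
For reference, a small sketch of that global knob (torch.backends.opt_einsum is the relevant PyTorch setting; the snippet is only illustrative):

import torch

# opt_einsum integration is controlled globally in PyTorch:
if torch.backends.opt_einsum.is_available():        # True iff opt_einsum is installed
    torch.backends.opt_einsum.enabled = True         # let torch.einsum optimize the contraction path
    torch.backends.opt_einsum.strategy = "optimal"   # "auto" (default), "greedy" or "optimal"

# With this, the call site can stay exactly as it is in the PR:
# G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)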


def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
    """
    X is a matrix of shape [k, n] and dY1, dY2 are matrices of shape [k, m].

There's actually no guarantee that X, dY1 and dY2 are matrices.

From the documentation of nn.Linear:

[screenshot of the nn.Linear shape documentation: input of shape (*, H_in), output of shape (*, H_out), where * means any number of dimensions, including none]

In particular, when there is no batch dim, I think the * dimension could be empty, and in transformers, the * dimension is (batch_size, seq_length), which is why transformers fail with this PR.
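
For illustration, a quick check of that shape behaviour (any number of leading dims, including none):

import torch
from torch import nn

linear = nn.Linear(8, 4)
print(linear(torch.randn(8)).shape)        # torch.Size([4])        -> no batch dim at all
print(linear(torch.randn(2, 8)).shape)     # torch.Size([2, 4])
print(linear(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 4])  -> the (batch, seq_length) case of transformers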

Contributor

@ValerianRey ValerianRey Oct 15, 2025


I made it work on Transformers with that:

if dY1.ndim == 1:
    G_b = torch.einsum("k,k->", dY1, dY2)
    G_W = torch.einsum("k,l,l,k->", dY1, X, X, dY2)
elif dY1.ndim == 2:
    G_b = torch.einsum("ak,ik->ai", dY1, dY2)
    G_W = torch.einsum("ak,al,il,ik->ai", dY1, X, X, dY2)
elif dY1.ndim == 3:  # Typical in transformers
    G_b = torch.einsum("abk,ijk->ai", dY1, dY2)
    G_W = torch.einsum("abk,abl,ijl,ijk->ai", dY1, X, X, dY2)
else:
    raise ValueError("Higher dimensions not supported. Open an issue if needed.")

Not elegant at all, but it seems to work. Maybe there's a clean way to write this that works for any number of dimensions, without the ifs. Also, please review the equations: I basically found them by trial and error until the tests passed.
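
One possible way to avoid the ifs, as a sketch only (the helper name is hypothetical; it assumes a batch dim at index 0 and the feature dim last, so the non-batched ndim == 1 case would still need its own branch): flatten all intermediate dims into one and contract once.

import torch
from torch import Tensor


def _compute_gramian_flattened(dY1: Tensor, dY2: Tensor, X: Tensor) -> tuple[Tensor, Tensor]:
    # dY1, dY2 are assumed to have shape (B, *, m) and X shape (B, *, n).
    # Flattening the * dims lets one pair of contractions cover ndim == 2, 3, 4, ...
    b = dY1.shape[0]
    dY1_ = dY1.reshape(b, -1, dY1.shape[-1])
    dY2_ = dY2.reshape(b, -1, dY2.shape[-1])
    X_ = X.reshape(b, -1, X.shape[-1])
    G_b = torch.einsum("ask,btk->ab", dY1_, dY2_)
    G_W = torch.einsum("ask,asl,btl,btk->ab", dY1_, X_, X_, dY2_)
    return G_b, G_W

For ndim == 2 the flattened dim has size 1, so this reduces to the "ak,al,il,ik->ai" case above; for ndim == 3 it matches the "abk,abl,ijl,ijk->ai" case.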

Contributor Author

@PierreQuinton PierreQuinton Oct 16, 2025


Well, it needs to be at least matrices (ndim >= 2), since we know it's a batched scenario. We could in principle also add the non-batched scenario, but I'm not sure it would be faster than the classical Jacobian-based GramianComputer.

Regarding "I did them basically with trial and error until the tests passed": I wish I could have done that ^^

@ValerianRey
Contributor

ValerianRey commented Oct 15, 2025

Also pretty big for Transformers (with the change I suggested to handle higher-order tensors).

Times for forward + backward on WithTransformerLarge with BS=256, A=Mean on cuda:0: reduced from 3.13 sec (main) to 2.20 sec (this PR).

Memory usage is however increased: the max batch size went from 273 (main) to 256 (this PR).

@ValerianRey
Contributor

ValerianRey commented Oct 15, 2025

This seems to break NoFreeParam (tiny errors) and ModuleReuse (large errors). I need to investigate that.

For ModuleReuse, my guess is that it simply doesn't consider cross terms anymore, so it's expected that it fails the test.

@PierreQuinton
Contributor Author

PierreQuinton commented Oct 16, 2025

Of interest: https://optimized-einsum.readthedocs.io/en/stable/reusing_paths.html
If we explore which contraction path is optimal and it turns out to essentially always be the same, then we may want to hard-code that contraction ourselves. It would be very helpful to know the optimal contraction order.
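
A sketch of how to inspect and reuse the path with opt_einsum (the shapes below are just an example):

import opt_einsum as oe
import torch

dY1, dY2 = torch.randn(64, 100), torch.randn(64, 100)
X = torch.randn(64, 50)

# Report the contraction order chosen for the weight-gramian einsum.
path, info = oe.contract_path("ik,il,jl,jk->ij", dY1, X, X, dY2, optimize="optimal")
print(path)   # pairwise contraction order, e.g. [(0, 1), (0, 2), (0, 1)]
print(info)   # FLOP count and intermediate sizes for that path

# If the order is essentially always the same, build the expression once and reuse it.
expr = oe.contract_expression(
    "ik,il,jl,jk->ij", dY1.shape, X.shape, X.shape, dY2.shape, optimize="optimal"
)
G_W = expr(dY1, X, X, dY2, backend="torch")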

@PierreQuinton PierreQuinton changed the title feat(autogram): Add LinearBasedGramianComputer. feat(autogram): Add ModuleBasedGramianComputer. Oct 17, 2025
@ValerianRey
Contributor

@PierreQuinton I found a way to compute the gramian with autograd, with no cross terms from module reuse / inter-module param reuse: 30fdc00. Basically, the idea is to have a module pre-hook that clones each parameter before it is used, and a module post-hook that restores the original params. This way, each module usage corresponds to a different clone, and you can compute a gradient with respect to each clone. The implementation is a context manager, so it's quite clean IMO (a rough sketch of the idea follows the list of limitations below). Current limitations:

  • Does not work on WithMultiheadAttention, WithTransformer and WithFreeParam, because they all involve some indirect parameters for the hooked module. Need to investigate and fix that (it's probably doable).
  • Still counts cross-terms from intra-module parameter reuse: we'd need a node-based algo (rather than module-based) to fix that. But since autogram is still module based, it doesn't matter yet.
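
A rough sketch of the clone-then-restore idea (the names are hypothetical and the actual context manager in 30fdc00 may differ in the details):

from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def clone_params(module: nn.Module, clones_per_usage: list):
    """On every forward of `module`, swap its direct parameters for fresh clones
    and restore the originals afterwards, so each usage gets its own copies."""
    originals = dict(module.named_parameters(recurse=False))

    def pre_hook(mod, args):
        clones = {}
        for name, param in originals.items():
            clone = param.clone()      # still differentiable w.r.t. the original param
            delattr(mod, name)         # unregister the nn.Parameter
            setattr(mod, name, clone)  # plain tensor attribute used by this forward only
            clones[name] = clone
        clones_per_usage.append(clones)

    def post_hook(mod, args, output):
        for name, param in originals.items():
            delattr(mod, name)
            setattr(mod, name, param)  # re-register the original nn.Parameter

    handles = [module.register_forward_pre_hook(pre_hook),
               module.register_forward_hook(post_hook)]
    try:
        yield
    finally:
        for handle in handles:
            handle.remove()


# Example: two usages of the same Linear get two independent weight clones,
# so the two gradients below contain no cross terms between usages.
linear = nn.Linear(4, 4)
usages = []
with clone_params(linear, usages):
    y = linear(linear(torch.randn(3, 4))).sum()
grad_first_usage, grad_second_usage = torch.autograd.grad(y, [u["weight"] for u in usages])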

* Add ModuleFactory and use it to instantiate models in tests
* Add get_in_out_shapes and use it to obtain input and output shapes in tests
* Move conftest.py from tests.unit to tests
* Separate DEVICE creation from conftest.py to device.py
* Add pytest_make_parametrize_id
* Add support for RNN, BatchNorm2d and InstanceNorm2d in get_in_out_shapes
* Remove WithRNN, WithBatchNorm and WithModuleTrackingRunningStats - use simple factories instead
* Revert removal of WithRNN (part of aff0abc)
* Fix output of WithRNN to not include the hidden state
* Extract rng forking into contexts.py
* Make _forward_pass do rng forking
* Make _forward_pass take reduction parameter
* Make forward_pass public
* Use forward_pass in test_engine.py, stop reseeding (it's now done by forward_pass)
* Make zipping strict in make_mse_loss_fn
* Stop requiring params in autograd_gramian_forward_backward
* Improve parameter order of autogram_forward_backward
* Rename some variables
* Factorize input and target creation into make_inputs_and_targets
* Reorder some code
* Add CloneParams context to consider each parameter usage on a per-module-usage basis.
* Add _get_losses_and_params_with_cross_terms, _get_losses_and_params_without_cross_terms, and _get_losses_and_params to select between both.
@ValerianRey ValerianRey force-pushed the linear-gramian-computer branch from 1d134b4 to 7a95b96 on October 20, 2025 at 15:59.