
Conversation

@ValerianRey
Contributor

@ValerianRey ValerianRey commented Oct 23, 2025

The idea of this PR is to remove the batched optimization, because this optimization should instead be made internally by backpropagating a diagonal sparse jacobian.

Compared to main, this simplifies things considerably. The batched optimization was done in FunctionalJacobianComputer, but required a different usage than AutogradJacobianComputer, which forced the engine to special-case on the batch_dim, which in turn forced the user to provide the batch_dim. All of this can now be dropped.
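To illustrate why the batched optimization can move into the type of the jac_output tensor: for a model applied independently to each sample of a batch, the Jacobian of the batched output with respect to the batched input is (block-)diagonal across the batch dimension, so a diagonal sparse representation loses nothing. A minimal sketch (not torchjd code, just a toy per-sample function):

```python
import torch

def per_sample(x):
    # Applied independently to each row of the batch: no cross-sample mixing.
    return x ** 2

x = torch.randn(4, 3)
# Full Jacobian of the batched output w.r.t. the batched input: shape [4, 3, 4, 3].
J = torch.autograd.functional.jacobian(per_sample, x)

# All cross-sample blocks are exactly zero, i.e. the Jacobian is
# block-diagonal over the batch dimension.
for i in range(4):
    for j in range(4):
        if i != j:
            assert torch.equal(J[i, :, j, :], torch.zeros(3, 3))
```

This is the structure that a DiagonalSparseTensor jac_output would let the engine exploit without any batch_dim special-casing.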

  • Remove batch_dim parameter from Engine
  • Remove test_batched_non_batched_equivalence and test_batched_non_batched_equivalence_2
  • Adapt all usages of Engine to not provide batch_dim
  • Remove FunctionalJacobianComputer
  • Remove args and kwargs from interface of JacobianComputer, GramianComputer and JacobianAccumulator because they were only needed for the functional interface
  • Remove kwargs from interface of Hook and stop registering it with with_kwargs=True (args are mandatory though, so rename them to _).
  • Change JacobianComputer to compute generalized jacobians (shape [m0, ..., mk, n]) and change GramianComputer to compute optional generalized gramians (shape [m0, ..., mk, mk, ..., m0])
  • Change engine.compute_gramian to simply apply one vmap level per dimension of the output, without caring about the batch_dim.
  • Remove all reshapes and movedims in engine.compute_gramian: we no longer need reshape since the gramian is directly a generalized gramian, and we no longer need movedim since we vmap over all dimensions the same way, without having to move the non-batched dim to the front. Merge compute_gramian and _compute_square_gramian.
  • Add a temporary function to create the initial jac_output (dense). This should later be updated to a tensor format optimized for batched computations.
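The "generalized jacobian" and "generalized gramian" shapes above can be sketched with a toy example (not the torchjd implementation; names and the exact reduction are illustrative). For an output of shape [m0, m1] and n parameters, the generalized Jacobian has shape [m0, m1, n], and contracting over the parameter axis, with the second copy of the output axes reversed, gives a generalized Gramian of shape [m0, m1, m1, m0]:

```python
import torch

n, m0, m1 = 5, 2, 3
W = torch.randn(m0 * m1, n)

def f(x):
    # Toy linear model: maps n parameters to an output of shape [m0, m1].
    return (W @ x).reshape(m0, m1)

x = torch.randn(n)
# Generalized Jacobian: output shape followed by the parameter axis.
J = torch.autograd.functional.jacobian(f, x)  # shape [m0, m1, n]

# Generalized Gramian: G[i, j, k, l] = sum_n J[i, j, n] * J[l, k, n],
# so its shape is [m0, m1, m1, m0] (output dims, then reversed output dims).
G = torch.einsum("ijn,lkn->ijkl", J, J)
print(G.shape)  # torch.Size([2, 3, 3, 2])
```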

* Use a DiagonalSparseTensor as initial jac_output of compute_gramian.
@ValerianRey ValerianRey added feat New feature or request package: autogram labels Oct 23, 2025
@ValerianRey ValerianRey self-assigned this Oct 23, 2025
@codecov

codecov bot commented Oct 23, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/torchjd/autogram/_engine.py 100.00% <100.00%> (ø)
src/torchjd/autogram/_gramian_computer.py 100.00% <100.00%> (ø)
src/torchjd/autogram/_jacobian_computer.py 100.00% <100.00%> (ø)
src/torchjd/autogram/_module_hook_manager.py 100.00% <100.00%> (ø)

@ValerianRey
Contributor Author

This PR is a prerequisite for being able to use DiagonalSparseTensors. It greatly simplifies the engine, making all the necessary changes so that the optimization now comes down entirely to the type of tensor we pass as jac_output.

So in a future PR (after #466 is merged), we will be able to simply replace:

jac_output = _make_initial_jac_output(output)

with

jac_output = DiagonalSparseTensor(...)

and remove _make_initial_jac_output.
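A dense initial jac_output of this kind is just an identity over the flattened output, viewed with the output shape repeated twice, so that backpropagating it yields the full generalized Jacobian. This is a hedged sketch of what such a helper could look like (the actual _make_initial_jac_output and its axis conventions in the PR may differ):

```python
import torch

def make_initial_jac_output(output: torch.Tensor) -> torch.Tensor:
    # Identity over the flattened output, reshaped to [*output.shape, *output.shape].
    # Backpropagating this seed recovers the full generalized Jacobian.
    eye = torch.eye(output.numel(), dtype=output.dtype, device=output.device)
    return eye.reshape(*output.shape, *output.shape)

out = torch.randn(2, 3)
jac_output = make_initial_jac_output(out)
print(jac_output.shape)  # torch.Size([2, 3, 2, 3])
```

Replacing this dense identity with a DiagonalSparseTensor holding the same values would be the only change needed to recover the batched optimization.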

In fact, it even works if we cherry-pick this into #466 and use a DiagonalSparseTensor as jac_output, but it gets densified very quickly, so it doesn't really exploit sparsity.

@ValerianRey ValerianRey changed the title feat(autogram): Stop using functional interface feat(autogram): Remove batched optimizations Oct 23, 2025
Contributor

@PierreQuinton PierreQuinton left a comment


Well done Nils, the code has disappeared, but that's fine. (LGTM after a few discussions)

@ValerianRey ValerianRey merged commit 481d83f into dev-new-engine Oct 23, 2025
17 checks passed
@ValerianRey ValerianRey deleted the simplify-engine branch October 23, 2025 17:28