
Commit d159430

feat: add autogram.Engine (batched) (#387)
Epic commit squashing all the work on the first version of the autogram engine.

Autogram package:
* Add autogram package
* Add autogram engine (batched only)
* Add EdgeRegistry
* Add GramianAccumulator
* Add ModuleHookManager
* Add AccumulateJacobian and JacobianAccumulator nodes
* Add get_functional_vjp

Aggregation package:
* Make all Weighting[PSDMatrix] classes public
* Have the same default values in weightings as we have in aggregators
* Make weighting wrappers take pref_vector directly rather than a base weighting
* Make Aggregator override __init__ to remove args and kwargs from its prototype
* Improve explanation of pref_vector in aggregators and weightings
* Put less emphasis on how to use aggregators directly: have a single usage example for both in aggregation/index.html rather than one per aggregator
* Remove table of properties about aggregators in aggregation/index.rst

Global structure:
* Make backward and mtl_backward importable from torchjd.autojac
* Deprecate importing backward and mtl_backward from torchjd
* Explain how to deprecate in CONTRIBUTING.md

Testing:
* Force using deterministic algorithms only on CPU
* Add slow marker and --runslow option, and explain how to run slow tests in CONTRIBUTING.md
* Add garbage_collect marker for tests that need to free cuda memory after they're run
* Reorganize test package: add speed and utils packages
* Generalize assert_tensor_dicts_are_close to any pair of dicts whose values are Tensors (rather than Tensor-to-Tensor dicts only)
* Add deprecation tests for the old way of importing backward and mtl_backward
* Add torchvision to test dependencies
* Add utils/architectures.py with many edge-case architectures and a few real-world architectures
* Add unit tests for the autogram engine
* Add speed tests for the autogram engine
* Add unit tests for the edge registry
* Add unit tests for the gramian accumulator
* Add value tests for Weightings
* Move value tests of aggregators from doc tests to unit tests

Documentation:
* Remove docstring from torchjd/__init__.py
* Move some documentation from dedicated rst files to the __init__.py of the package
* Reorganize documentation: autojac, autogram, aggregation
* Revamp documentation to put more emphasis on autogram and weightings
* Add partial IWRM example
* Update IWRM example
* Add documentation entry for autogram.Engine
1 parent 3ac4f12 commit d159430
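
For orientation before the diff: the headline feature is an engine that computes the Gramian of the Jacobian of per-instance losses, to be combined with the newly public weightings. A minimal sketch of that IWRM workflow, assuming an API of the form `Engine(model, batch_dim=0)` and `engine.compute_gramian(losses)` (pieced together from the changelog and documentation entries below, not verbatim from this commit):

```python
import torch
from torch.nn import Linear, MSELoss
from torchjd.aggregation import UPGradWeighting
from torchjd.autogram import Engine  # added by this commit

model = Linear(10, 1)
inputs, targets = torch.randn(8, 10), torch.randn(8, 1)

# One loss per instance (IWRM): the Jacobian has one row per sample.
losses = MSELoss(reduction="none")(model(inputs), targets).squeeze(1)

# Assumed API: compute the Gramian of the Jacobian of the per-instance
# losses with respect to the model's parameters, batched over dim 0.
engine = Engine(model, batch_dim=0)
gramian = engine.compute_gramian(losses)

# A public Weighting turns the Gramian into one weight per loss; a weighted
# backward pass then fills each parameter's .grad with the aggregated direction.
weights = UPGradWeighting()(gramian)
torch.autograd.backward(losses, weights)
```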

128 files changed (+3162, -950 lines)


CHANGELOG.md

Lines changed: 30 additions & 0 deletions
@@ -8,6 +8,36 @@ changes that do not affect the user.
 
 ## [Unreleased]
 
+### Added
+
+- Added the `autogram` package, with the `autogram.Engine`. This is an implementation of Algorithm 3
+  from [Jacobian Descent for Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232),
+  optimized for batched computations, as in IWRM.
+- For all `Aggregator`s based on the weighting of the Gramian of the Jacobian, made their
+  `Weighting` class public. It can be used directly on a Gramian (computed via the
+  `autogram.Engine`) to extract some weights. The list of new public classes is:
+  - `Weighting` (abstract base class)
+  - `UPGradWeighting`
+  - `AlignedMTLWeighting`
+  - `CAGradWeighting`
+  - `ConstantWeighting`
+  - `DualProjWeighting`
+  - `IMTLGWeighting`
+  - `KrumWeighting`
+  - `MeanWeighting`
+  - `MGDAWeighting`
+  - `PCGradWeighting`
+  - `RandomWeighting`
+  - `SumWeighting`
+- Added usage example for IWRM with autogram.
+- Added usage example for IWRM with partial autogram.
+
+### Changed
+
+- Revamped documentation.
+- Made `backward` and `mtl_backward` importable from `torchjd.autojac` (like it was prior to 0.7.0).
+- Deprecated importing `backward` and `mtl_backward` from `torchjd` directly.
+
 ## [0.7.0] - 2025-06-04
 
 ### Changed
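
The relationship between aggregators and their newly public weightings can be made concrete: a weighted aggregator satisfies `A(J) = J^T w`, where `w` comes from weighting the Gramian `J @ J.T`. A small sketch of that identity, assuming the weightings are callable on a Gramian (our assumption about the call convention):

```python
import torch
from torchjd.aggregation import UPGrad, UPGradWeighting

J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])  # Jacobian: one row per objective
G = J @ J.T  # Gramian of the Jacobian

w = UPGradWeighting()(G)  # assumed call convention: weighting applied to the Gramian
d = J.T @ w               # weighted aggregators satisfy A(J) = J^T w

assert torch.allclose(d, UPGrad()(J), atol=1e-4)  # should recover the aggregator's output
```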

CONTRIBUTING.md

Lines changed: 15 additions & 4 deletions
@@ -51,11 +51,16 @@ uv run pre-commit install
 ```
 
 ## Running tests
-- To verify that your installation was successful, and that all unit tests pass, run:
+- To verify that your installation was successful, and that unit tests pass, run:
   ```bash
   uv run pytest tests/unit
   ```
 
+- To also run the unit tests that are marked as slow, add the `--runslow` flag:
+  ```bash
+  uv run pytest tests/unit --runslow
+  ```
+
 - If you have access to a cuda-enabled GPU, you should also check that the unit tests pass on it:
   ```bash
   CUBLAS_WORKSPACE_CONFIG=:4096:8 PYTEST_TORCH_DEVICE=cuda:0 uv run pytest tests/unit
@@ -113,19 +118,20 @@ We ask contributors to implement the unit tests necessary to check the correctness of their
 implementations. Besides, whenever usage examples are provided, we require the example's code to be
 tested in `tests/doc`. We require a very high code coverage for newly introduced sources (~95-100%).
 To ensure that the tensors generated during the tests are on the right device, you have to use the
-partial functions defined in `tests/unit/_utils.py` to instantiate tensors. For instance, instead of
+partial functions defined in `tests/utils/tensors.py` to instantiate tensors. For instance, instead
+of
 ```python
 import torch
 a = torch.ones(3, 4)
 ```
 use
 ```python
-from unit._utils import ones_
+from utils.tensors import ones_
 a = ones_(3, 4)
 ```
 
 This will automatically call `torch.ones` with `device=unit.conftest.DEVICE`.
-If the function you need does not exist yet as a partial function in `_utils.py`, add it.
+If the function you need does not exist yet as a partial function in `tensors.py`, add it.
 Lastly, when you create a model or a random generator, you have to move them manually to the right
 device (the `DEVICE` defined in `unit.conftest`):
 ```python
@@ -162,6 +168,11 @@ implementation of a mathematical aggregator.
 > Before working on the implementation of a new aggregator, please contact us via an issue or a
 > discussion: in many cases, we have already thought about it, or even started an implementation.
 
+## Deprecation
+
+To deprecate some public functionality, make it raise a `DeprecationWarning`. A test should also be
+added in `tests/units/test_deprecations.py`, ensuring that this warning is issued.
+
 ## Release
 
 *This section is addressed to maintainers.*
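
Two of the hunks above invite short sketches. First, the device-pinning helpers in `tests/utils/tensors.py` are described as partial functions that forward to the corresponding torch constructors; a plausible implementation, assuming `DEVICE` is exposed by `unit.conftest` as stated:

```python
from functools import partial

import torch
from unit.conftest import DEVICE  # device selected for the current test session

# Each helper behaves exactly like its torch counterpart, with the device pinned.
ones_ = partial(torch.ones, device=DEVICE)
zeros_ = partial(torch.zeros, device=DEVICE)
randn_ = partial(torch.randn, device=DEVICE)
```

Second, the new Deprecation section boils down to the standard `warnings` pattern plus a `pytest.warns` check. A minimal sketch with hypothetical names (`old_function`, `new_function`); torchjd's actual shims may differ:

```python
import warnings

import pytest

def new_function(x):
    return 2 * x

def old_function(x):
    """Deprecated alias: warn, then delegate to the replacement."""
    warnings.warn(
        "`old_function` is deprecated; use `new_function` instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return new_function(x)

def test_old_function_is_deprecated():
    # The kind of test that belongs in the deprecation test file.
    with pytest.warns(DeprecationWarning):
        assert old_function(3) == 6
```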

docs/source/docs/aggregation/aligned_mtl.rst

Lines changed: 5 additions & 0 deletions
@@ -7,3 +7,8 @@ Aligned-MTL
    :members:
    :undoc-members:
    :exclude-members: forward
+
+.. autoclass:: torchjd.aggregation.AlignedMTLWeighting
+   :members:
+   :undoc-members:
+   :exclude-members: forward

docs/source/docs/aggregation/bases.rst

Lines changed: 0 additions & 10 deletions
This file was deleted.

docs/source/docs/aggregation/cagrad.rst

Lines changed: 5 additions & 0 deletions
@@ -7,3 +7,8 @@ CAGrad
    :members:
    :undoc-members:
    :exclude-members: forward
+
+.. autoclass:: torchjd.aggregation.CAGradWeighting
+   :members:
+   :undoc-members:
+   :exclude-members: forward

docs/source/docs/aggregation/constant.rst

Lines changed: 5 additions & 0 deletions
@@ -7,3 +7,8 @@ Constant
    :members:
    :undoc-members:
    :exclude-members: forward
+
+.. autoclass:: torchjd.aggregation.ConstantWeighting
+   :members:
+   :undoc-members:
+   :exclude-members: forward

docs/source/docs/aggregation/dualproj.rst

Lines changed: 5 additions & 0 deletions
@@ -7,3 +7,8 @@ DualProj
    :members:
    :undoc-members:
    :exclude-members: forward
+
+.. autoclass:: torchjd.aggregation.DualProjWeighting
+   :members:
+   :undoc-members:
+   :exclude-members: forward

docs/source/docs/aggregation/imtl_g.rst

Lines changed: 5 additions & 0 deletions
@@ -7,3 +7,8 @@ IMTL-G
    :members:
    :undoc-members:
    :exclude-members: forward
+
+.. autoclass:: torchjd.aggregation.IMTLGWeighting
+   :members:
+   :undoc-members:
+   :exclude-members: forward

docs/source/docs/aggregation/index.rst

Lines changed: 13 additions & 132 deletions
@@ -1,146 +1,27 @@
-Aggregation
+aggregation
 ===========
 
-A mapping :math:`\mathcal A: \mathbb R^{m\times n} \to \mathbb R^n` reducing any matrix
-:math:`J \in \mathbb R^{m\times n}` into its aggregation :math:`\mathcal A(J) \in \mathbb R^n` is
-called an aggregator.
+.. automodule:: torchjd.aggregation
+   :no-members:
 
-In the context of JD, the matrix to aggregate is a Jacobian whose rows are the gradients of the
-individual objectives. The aggregator is used to reduce this matrix into an update vector for the
-parameters of the model.
+Abstract base classes
+---------------------
 
-In TorchJD, an aggregator is a class that inherits from the abstract class
-:doc:`Aggregator <bases>`. We provide the following list of aggregators from the literature:
-
-.. role:: raw-html(raw)
-   :format: html
-
-.. |yes| replace:: :raw-html:`<center><font color="#28b528">✔</font></center>`
-.. |no| replace:: :raw-html:`<center><font color="#e63232">✘</font></center>`
-
-.. list-table::
-   :widths: 25 15 15 15
-   :header-rows: 1
-
-   * - :doc:`Aggregator <bases>`
-     - :ref:`Non-conflicting <Non-conflicting>`
-     - :ref:`Linear under scaling <Linear under scaling>`
-     - :ref:`Weighted <Weighted>`
-   * - :doc:`UPGrad <upgrad>` (recommended)
-     - |yes|
-     - |yes|
-     - |yes|
-   * - :doc:`Aligned-MTL <aligned_mtl>`
-     - |no|
-     - |no|
-     - |yes|
-   * - :doc:`CAGrad <cagrad>`
-     - |no|
-     - |no|
-     - |yes|
-   * - :doc:`ConFIG <config>`
-     - |no|
-     - |yes|
-     - |yes|
-   * - :doc:`Constant <constant>`
-     - |no|
-     - |yes|
-     - |yes|
-   * - :doc:`DualProj <dualproj>`
-     - |yes|
-     - |no|
-     - |yes|
-   * - :doc:`GradDrop <graddrop>`
-     - |no|
-     - |no|
-     - |no|
-   * - :doc:`IMTL-G <imtl_g>`
-     - |no|
-     - |no|
-     - |yes|
-   * - :doc:`Krum <krum>`
-     - |no|
-     - |no|
-     - |yes|
-   * - :doc:`Mean <mean>`
-     - |no|
-     - |yes|
-     - |yes|
-   * - :doc:`MGDA <mgda>`
-     - |yes|
-     - |no|
-     - |yes|
-   * - :doc:`Nash-MTL <nash_mtl>`
-     - |yes|
-     - |no|
-     - |yes|
-   * - :doc:`PCGrad <pcgrad>`
-     - |no|
-     - |yes|
-     - |yes|
-   * - :doc:`Random <random>`
-     - |no|
-     - |yes|
-     - |yes|
-   * - :doc:`Sum <sum>`
-     - |no|
-     - |yes|
-     - |yes|
-   * - :doc:`Trimmed Mean <trimmed_mean>`
-     - |no|
-     - |no|
-     - |no|
-
-.. hint::
-   This table is an adaptation of the one available in `Jacobian Descent For Multi-Objective
-   Optimization <https://arxiv.org/pdf/2406.16232>`_. The paper provides precise justification of
-   the properties in Section 2.2 as well as proofs in Appendix B.
-
-.. _Non-conflicting:
-.. admonition::
-   Non-conflicting
-
-   An aggregator :math:`\mathcal A: \mathbb R^{m\times n} \to \mathbb R^n` is said to be
-   *non-conflicting* if for any :math:`J\in\mathbb R^{m\times n}`, :math:`J\cdot\mathcal A(J)` is a
-   vector with only non-negative elements.
-
-   In other words, :math:`\mathcal A` is non-conflicting whenever the aggregation of any matrix has
-   non-negative inner product with all rows of that matrix. In the context of JD, this ensures that
-   no objective locally increases.
-
-.. _Linear under scaling:
-.. admonition::
-   Linear under scaling
-
-   An aggregator :math:`\mathcal A: \mathbb R^{m\times n} \to \mathbb R^n` is said to be
-   *linear under scaling* if for any :math:`J\in\mathbb R^{m\times n}`, the mapping from any
-   positive :math:`c\in\mathbb R^{m}` to :math:`\mathcal A(\operatorname{diag}(c)\cdot J)` is
-   linear in :math:`c`.
-
-   In other words, :math:`\mathcal A` is linear under scaling whenever scaling a row of the matrix
-   to aggregate scales its influence proportionally. In the context of JD, this ensures that even
-   when the gradient norms are imbalanced, each gradient will contribute to the update
-   proportionally to its norm.
-
-.. _Weighted:
-.. admonition::
-   Weighted
-
-   An aggregator :math:`\mathcal A: \mathbb R^{m\times n} \to \mathbb R^n` is said to be *weighted*
-   if for any :math:`J\in\mathbb R^{m\times n}`, there exists a weight vector
-   :math:`w\in\mathbb R^m` such that :math:`\mathcal A(J)=J^\top w`.
-
-   In other words, :math:`\mathcal A` is weighted whenever the aggregation of any matrix is always
-   in the span of the rows of that matrix. This ensures a higher precision of the Taylor
-   approximation that JD relies on.
+.. autoclass:: torchjd.aggregation.Aggregator
+   :members:
+   :undoc-members:
+   :exclude-members: forward
 
+.. autoclass:: torchjd.aggregation.Weighting
+   :members:
+   :undoc-members:
+   :exclude-members: forward
 
 
 .. toctree::
    :hidden:
    :maxdepth: 1
 
-   bases.rst
    upgrad.rst
    aligned_mtl.rst
   cagrad.rst
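
The removed admonitions define properties that are easy to check numerically. A small sketch verifying the non-conflicting property (the example matrix is ours, not from the docs): `UPGrad` should yield an update whose inner product with every gradient is non-negative, while `Mean` need not.

```python
import torch
from torchjd.aggregation import Mean, UPGrad

# Two objectives whose gradients conflict in the first coordinate.
J = torch.tensor([[-4.0, 1.0, 1.0],
                  [ 6.0, 1.0, 1.0]])

for A in (UPGrad(), Mean()):
    d = A(J)       # aggregation of the Jacobian
    inner = J @ d  # inner product of the update with each row (gradient)
    print(type(A).__name__, (inner >= 0).all().item())
# UPGrad is non-conflicting, so it prints True; Mean prints False on this J,
# because the larger gradient overpowers the smaller, conflicting one.
```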

docs/source/docs/aggregation/krum.rst

Lines changed: 5 additions & 0 deletions
@@ -7,3 +7,8 @@ Krum
    :members:
    :undoc-members:
    :exclude-members: forward
+
+.. autoclass:: torchjd.aggregation.KrumWeighting
+   :members:
+   :undoc-members:
+   :exclude-members: forward
