
Commit 6e98e60

Merge branch 'main' into autogram-readme
2 parents: d392656 + bab9a50

28 files changed (+1416 −463 lines)

CHANGELOG.md

Lines changed: 7 additions & 1 deletion
@@ -12,7 +12,8 @@ changes that do not affect the user.
 
 - Added the `autogram` package, with the `autogram.Engine`. This is an implementation of Algorithm 3
   from [Jacobian Descent for Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232),
-  optimized for batched computations, as in IWRM.
+  optimized for batched computations, as in IWRM. Generalized Gramians can also be obtained by using
+  the autogram engine on a tensor of losses of arbitrary shape.
 - For all `Aggregator`s based on the weighting of the Gramian of the Jacobian, made their
   `Weighting` class public. It can be used directly on a Gramian (computed via the
   `autogram.Engine`) to extract some weights. The list of new public classes is:
@@ -29,11 +30,16 @@ changes that do not affect the user.
   - `PCGradWeighting`
   - `RandomWeighting`
   - `SumWeighting`
+- Added `GeneralizedWeighting` (base class) and `Flattening` (implementation) to extract tensors of
+  weights from generalized Gramians.
 - Added usage example for IWRM with autogram.
 - Added usage example for IWRM with partial autogram.
+- Added usage example for IWMTL with autogram.
 
 ### Changed
 
+- Removed an unnecessary internal reshape when computing Jacobians. This should have no effect but a
+  slight performance improvement in `autojac`.
 - Revamped documentation.
 - Made `backward` and `mtl_backward` importable from `torchjd.autojac` (like it was prior to 0.7.0).
 - Deprecated importing `backward` and `mtl_backward` from `torchjd` directly.
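
As context for the changelog entries above: a public `Weighting` is applied directly to a Gramian computed by the `autogram.Engine`. The following is a minimal sketch of that workflow, assembled from the IWRM and IWMTL examples further down in this commit; the model, data, and hyperparameters are placeholders.

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import UPGradWeighting
    from torchjd.autogram import Engine

    model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
    optimizer = SGD(model.parameters(), lr=0.1)
    mse = MSELoss(reduction="none")  # keep one loss per element of the batch
    weighting = UPGradWeighting()
    engine = Engine(model.modules(), batch_dim=0)

    x = torch.randn(16, 10)  # a batch of 16 input vectors of length 10
    y = torch.randn(16)      # 16 targets

    losses = mse(model(x).squeeze(dim=1), y)  # shape: [16]
    gramian = engine.compute_gramian(losses)  # shape: [16, 16]
    weights = weighting(gramian)              # shape: [16]

    optimizer.zero_grad()
    losses.backward(weights)  # standard backward pass, weighted by the extracted weights
    optimizer.step()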

docs/source/docs/aggregation/flattening.rst

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+:hide-toc:
+
+Flattening
+==========
+
+.. autoclass:: torchjd.aggregation.Flattening
+    :members:
+    :undoc-members:
+    :exclude-members: forward

docs/source/docs/aggregation/index.rst

Lines changed: 6 additions & 0 deletions
@@ -17,6 +17,11 @@ Abstract base classes
     :undoc-members:
     :exclude-members: forward
 
+.. autoclass:: torchjd.aggregation.GeneralizedWeighting
+    :members:
+    :undoc-members:
+    :exclude-members: forward
+
 
 .. toctree::
     :hidden:
@@ -28,6 +33,7 @@ Abstract base classes
     config.rst
     constant.rst
     dualproj.rst
+    flattening.rst
     graddrop.rst
     imtl_g.rst
     krum.rst

docs/source/examples/index.rst

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,10 @@ This section contains some usage examples for TorchJD.
 - :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian
   descent is used to optimize the vector of per-task losses of a multi-task model, using the
   dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`.
+- :doc:`Instance-Wise Multi-Task Learning (IWMTL) <iwmtl>` shows how to combine multi-task learning
+  with instance-wise risk minimization: one loss per task and per element of the batch, using the
+  :doc:`autogram.Engine <../docs/autogram/engine>` and a :doc:`GeneralizedWeighting
+  <../docs/aggregation/index>`.
 - :doc:`Recurrent Neural Network (RNN) <rnn>` shows how to apply Jacobian descent to RNN training,
   with one loss per output sequence element.
 - :doc:`Monitoring Aggregations <monitoring>` shows how to monitor the aggregation performed by the
@@ -34,6 +38,7 @@ This section contains some usage examples for TorchJD.
     iwrm.rst
     partial_jd.rst
     mtl.rst
+    iwmtl.rst
     rnn.rst
     monitoring.rst
     lightning_integration.rst

docs/source/examples/iwmtl.rst

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+Instance-Wise Multi-Task Learning (IWMTL)
+=========================================
+
+When training a model with multiple tasks, the gradients of the individual tasks are likely to
+conflict. This is particularly true when looking at the individual (per-sample) gradients.
+The :doc:`autogram engine <../docs/autogram/engine>` can be used to efficiently compute the Gramian
+of the Jacobian of the matrix of per-sample and per-task losses. Weights can then be extracted from
+this Gramian to reweight the gradients and resolve conflict entirely.
+
+The following example shows how to do that.
+
+.. code-block:: python
+    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
+
+    import torch
+    from torch.nn import Linear, MSELoss, ReLU, Sequential
+    from torch.optim import SGD
+
+    from torchjd.aggregation import Flattening, UPGradWeighting
+    from torchjd.autogram import Engine
+
+    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
+    task1_module = Linear(3, 1)
+    task2_module = Linear(3, 1)
+    params = [
+        *shared_module.parameters(),
+        *task1_module.parameters(),
+        *task2_module.parameters(),
+    ]
+
+    optimizer = SGD(params, lr=0.1)
+    mse = MSELoss(reduction="none")
+    weighting = Flattening(UPGradWeighting())
+    engine = Engine(shared_module.modules(), batch_dim=0)
+
+    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
+    task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
+    task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task
+
+    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
+        features = shared_module(input)  # shape: [16, 3]
+        out1 = task1_module(features).squeeze(1)  # shape: [16]
+        out2 = task2_module(features).squeeze(1)  # shape: [16]
+
+        # Compute the matrix of losses: one loss per element of the batch and per task
+        losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]
+
+        # Compute the gramian (inner products between pairs of gradients of the losses)
+        gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]
+
+        # Obtain the weights that lead to no conflict between reweighted gradients
+        weights = weighting(gramian)  # shape: [16, 2]
+
+        optimizer.zero_grad()
+        # Do the standard backward pass, but weighted using the obtained weights
+        losses.backward(weights)
+        optimizer.step()
+
+.. note::
+    In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a
+    4D tensor rather than a matrix, and a
+    :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting`, such as
+    :class:`~torchjd.aggregation._flattening.Flattening`, has to be used to extract a matrix of
+    weights from it. More information about ``GeneralizedWeighting`` can be found in the
+    :doc:`../../docs/aggregation/index` page.

docs/source/examples/iwrm.rst

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
     params = model.parameters()
     optimizer = SGD(params, lr=0.1)
     weighting = UPGradWeighting()
-    engine = Engine(model.modules())
+    engine = Engine(model.modules(), batch_dim=0)
 
     for x, y in zip(X, Y):
         y_hat = model(x).squeeze(dim=1)  # shape: [16]

docs/source/examples/partial_jd.rst

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.
 
     # Create the autogram engine that will compute the Gramian of the
    # Jacobian with respect to the two last Linear layers' parameters.
-    engine = Engine(model[2:].modules())
+    engine = Engine(model[2:].modules(), batch_dim=0)
 
     params = model.parameters()
     optimizer = SGD(params, lr=0.1)

docs/source/index.rst

Lines changed: 7 additions & 2 deletions
@@ -38,8 +38,8 @@ per-task losses has to be minimized. To start using TorchJD for multi-task learn
 Another more interesting application is to consider separately the loss of each element in the
 batch. This is what we define as :doc:`Instance-Wise Risk Minimization <examples/iwrm>` (IWRM).
 
-For IWRM, in many cases, there exists an algorithm that is both equivalent to Jacobian descent, and
-much more efficient. This algorithm, called Gramian-based Jacobian descent, consists in computing
+The Gramian-based Jacobian descent algorithm provides a very efficient alternative way of
+performing Jacobian descent. It consists in computing
 the Gramian of the Jacobian iteratively during the backward pass (without ever storing the full
 Jacobian in memory), weighting the losses using the information of the Gramian, and then computing
 the gradient of the obtained weighted loss. The iterative computation of the Gramian corresponds to
@@ -48,6 +48,11 @@ Algorithm 3 of
 documentation and usage example of this algorithm is provided in
 :doc:`autogram.Engine <docs/autogram/engine>`.
 
+The original usage of the autogram engine is to compute the Gramian of the Jacobian very efficiently
+for :doc:`IWRM <examples/iwrm>`. Another direct application is when considering one loss per element
+of the batch and per task, in the context of multi-task learning. We call this
+:doc:`Instance-Wise Multi-Task Learning <examples/iwmtl>` (IWMTL).
+
 TorchJD is open-source, under MIT License. The source code is available on
 `GitHub <https://github.com/TorchJD/torchjd>`_.
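
To make the quantities described above concrete: for a vector of m losses with Jacobian J of shape [m, n], the Gramian is G = J J^T, and the gradient of the weighted loss w^T L is J^T w, whose norm depends only on G and w. A small hand-rolled sketch of this relationship (plain PyTorch, not the torchjd API):

    import torch

    m, n = 3, 5            # 3 objectives, 5 parameters
    J = torch.randn(m, n)  # Jacobian of the losses with respect to the parameters

    G = J @ J.T            # Gramian: pairwise inner products of per-loss gradients, shape [m, m]
    w = torch.rand(m)      # weights that a Weighting would extract from the Gramian

    g = J.T @ w            # gradient of the weighted loss w^T L, shape [n]

    # The squared norm of the aggregated gradient depends only on the Gramian and the weights
    assert torch.allclose(g @ g, w @ G @ w)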

src/torchjd/aggregation/__init__.py

Lines changed: 27 additions & 4 deletions
@@ -1,8 +1,7 @@
 r"""
 When doing Jacobian descent, the Jacobian matrix has to be aggregated into a vector to store in the
 ``.grad`` fields of the model parameters. The
-The :class:`~torchjd.aggregation._aggregator_bases.Aggregator` is responsible for these
-aggregations.
+:class:`~torchjd.aggregation._aggregator_bases.Aggregator` is responsible for these aggregations.
 
 When using the :doc:`autogram <../autogram/index>` engine, we rather need to extract a vector
 of weights from the Gramian of the Jacobian. The
@@ -20,7 +19,8 @@
 :class:`Aggregators <torchjd.aggregation._aggregator_bases.Aggregator>` and :class:`Weightings
 <torchjd.aggregation._weighting_bases.Weighting>` are callables that take a Jacobian matrix or a
 Gramian matrix as inputs, respectively. The following example shows how to use UPGrad to either
-aggregate a Jacobian or obtain the weights from the Gramian of the Jacobian.
+aggregate a Jacobian (of shape ``[m, n]``, where ``m`` is the number of objectives and ``n`` is the
+number of parameters), or obtain the weights from the Gramian of the Jacobian (of shape ``[m, m]``).
 
 >>> from torch import tensor
 >>> from torchjd.aggregation import UPGrad, UPGradWeighting
@@ -35,13 +35,36 @@
 >>> weights = weighting(gramian)
 >>> weights
 tensor([1.1109, 0.7894])
+
+When dealing with a more general tensor of objectives, of shape ``[m_1, ..., m_k]`` (i.e. not
+necessarily a simple vector), the Jacobian will be of shape ``[m_1, ..., m_k, n]``, and its Gramian
+will be called a `generalized Gramian`, of shape ``[m_1, ..., m_k, m_k, ..., m_1]``. One can use a
+:class:`GeneralizedWeighting<torchjd.aggregation._weighting_bases.GeneralizedWeighting>` to extract
+a tensor of weights (of shape ``[m_1, ..., m_k]``) from such a generalized Gramian. The simplest
+:class:`GeneralizedWeighting<torchjd.aggregation._weighting_bases.GeneralizedWeighting>` is
+:class:`Flattening<torchjd.aggregation._flattening.Flattening>`: it simply "flattens" the
+generalized Gramian into a square Gramian matrix (of shape ``[m_1 * ... * m_k, m_1 * ... * m_k]``),
+applies a normal weighting to it to obtain a vector of weights, and returns the reshaped tensor of
+weights.
+
+>>> from torch import ones
+>>> from torchjd.aggregation import Flattening, UPGradWeighting
+>>>
+>>> weighting = Flattening(UPGradWeighting())
+>>> # Generate a generalized Gramian filled with ones, for the sake of the example
+>>> generalized_gramian = ones((2, 3, 3, 2))
+>>> weights = weighting(generalized_gramian)
+>>> weights
tensor([[0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667]])
 """
 
 from ._aggregator_bases import Aggregator
 from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting
 from ._config import ConFIG
 from ._constant import Constant, ConstantWeighting
 from ._dualproj import DualProj, DualProjWeighting
+from ._flattening import Flattening
 from ._graddrop import GradDrop
 from ._imtl_g import IMTLG, IMTLGWeighting
 from ._krum import Krum, KrumWeighting
@@ -55,7 +78,7 @@
 from ._utils.check_dependencies import (
     OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
 )
-from ._weighting_bases import Weighting
+from ._weighting_bases import GeneralizedWeighting, Weighting
 
 try:
     from ._cagrad import CAGrad, CAGradWeighting

src/torchjd/aggregation/_flattening.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from math import prod
+
+from torch import Tensor
+
+from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting
+from torchjd.autogram._gramian_utils import reshape_gramian
+
+
+class Flattening(GeneralizedWeighting):
+    """
+    :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting` flattening the generalized
+    Gramian into a square matrix, extracting a vector of weights from it using a
+    :class:`~torchjd.aggregation._weighting_bases.Weighting`, and returning the reshaped tensor of
+    weights.
+
+    For instance, when applied to a generalized Gramian of shape ``[2, 3, 3, 2]``, it would flatten
+    it into a square Gramian matrix of shape ``[6, 6]``, apply the weighting on it to get a vector
+    of weights of shape ``[6]``, and then return this vector reshaped into a matrix of shape
+    ``[2, 3]``.
+
+    :param weighting: The weighting to apply to the Gramian matrix.
+    """
+
+    def __init__(self, weighting: Weighting[PSDMatrix]):
+        super().__init__()
+        self.weighting = weighting
+
+    def forward(self, generalized_gramian: Tensor) -> Tensor:
+        k = generalized_gramian.ndim // 2
+        shape = generalized_gramian.shape[:k]
+        m = prod(shape)
+        square_gramian = reshape_gramian(generalized_gramian, [m])
+        weights_vector = self.weighting(square_gramian)
+        weights = weights_vector.reshape(shape)
+        return weights
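
`reshape_gramian` is an internal helper that is not shown in this commit. As a rough illustration of what the flattening amounts to, here is a hypothetical standalone equivalent, assuming the trailing axes of a generalized Gramian are stored in reversed order, as the `[m_1, ..., m_k, m_k, ..., m_1]` shape convention in the docstrings suggests; this is an approximation for illustration, not the actual implementation.

    from math import prod

    import torch


    def flatten_gramian(generalized_gramian: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for reshape_gramian(generalized_gramian, [m]):
        # bring the column axes into the same order as the row axes, then reshape
        # the [m_1, ..., m_k, m_k, ..., m_1] tensor into a square [m, m] matrix.
        k = generalized_gramian.ndim // 2
        shape = generalized_gramian.shape[:k]
        m = prod(shape)
        permutation = list(range(k)) + list(range(2 * k - 1, k - 1, -1))
        return generalized_gramian.permute(permutation).reshape(m, m)


    # A generalized Gramian of shape [2, 3, 3, 2] becomes a square Gramian of shape [6, 6].
    print(flatten_gramian(torch.ones(2, 3, 3, 2)).shape)  # torch.Size([6, 6])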
