
Commit 0f7cd8c

Update README.md for autogram
1 parent 7937223 commit 0f7cd8c

File tree

1 file changed

README.md

Lines changed: 73 additions & 41 deletions
@@ -55,8 +55,23 @@ Some aggregators may have additional dependencies. Please refer to the
 [installation documentation](https://torchjd.org/stable/installation) for them.
 
 ## Usage
-The main way to use TorchJD is to replace the usual call to `loss.backward()` by a call to
-`torchjd.backward` or `torchjd.mtl_backward`, depending on the use-case.
+There are two main ways to use TorchJD. The first one is to replace the usual call to
+`loss.backward()` with a call to
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) or
+[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/), depending
+on the use-case. This will compute the Jacobian of the vector of losses with respect to the model
+parameters, and aggregate it with the specified
+[`Aggregator`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Aggregator).
+Whenever you want to optimize the vector of per-sample losses, you should instead use the
+[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine.html). Instead of
+computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a
+memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from
+this Gramian, using a
+[`Weighting`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Weighting),
+and used to combine the losses of the batch. Assuming each element of the batch is
+processed independently from the others, this approach is equivalent to
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) while being
+generally much faster due to its lower memory usage.
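
A minimal sketch of this first approach, assuming that `torchjd.autojac.backward` accepts the
sequence of losses followed by the `Aggregator` to apply, could look like this:

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.autojac import backward
from torchjd.aggregation import UPGrad

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
loss_fn = MSELoss()
optimizer = SGD(model.parameters(), lr=0.1)
aggregator = UPGrad()

input = torch.randn(16, 10)  # a batch of 16 random input vectors of length 10
target = torch.randn(16, 2)  # a batch of 16 targets with 2 values each

output = model(input)
loss1 = loss_fn(output[:, 0], target[:, 0])  # first objective
loss2 = loss_fn(output[:, 1], target[:, 1])  # second objective

optimizer.zero_grad()
# Compute the Jacobian of [loss1, loss2] with respect to the model parameters and
# aggregate it with UPGrad, writing the result into the parameters' .grad fields.
backward([loss1, loss2], aggregator)
optimizer.step()
```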
 
 The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
 using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
@@ -66,7 +81,7 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 from torch.nn import Linear, MSELoss, ReLU, Sequential
 from torch.optim import SGD
 
-+ from torchjd import mtl_backward
++ from torchjd.autojac import mtl_backward
 + from torchjd.aggregation import UPGrad
 
 shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
@@ -104,49 +119,66 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 > In this example, the Jacobian is only with respect to the shared parameters. The task-specific
 > parameters are simply updated via the gradient of their task’s loss with respect to them.
 
+The following example shows how to use TorchJD to minimize the vector of per-instance losses with
+Jacobian descent, using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
+
+```diff
+import torch
+from torch.nn import Linear, MSELoss, ReLU, Sequential
+from torch.optim import SGD
+
++ from torchjd.autogram import Engine
++ from torchjd.aggregation import UPGradWeighting
+
+model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU())
+
+- loss_fn = MSELoss()
++ loss_fn = MSELoss(reduction="none")
+optimizer = SGD(model.parameters(), lr=0.1)
+
++ weighting = UPGradWeighting()
++ engine = Engine(model.modules())
+
+inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
+targets = torch.randn(8, 16)  # 8 batches of 16 targets
+
+for input, target in zip(inputs, targets):
+    output = model(input).squeeze(dim=1)  # shape: [16]
+-   loss = loss_fn(output, target)  # shape: []
++   losses = loss_fn(output, target)  # shape: [16]
+
+    optimizer.zero_grad()
+-   loss.backward()
++   gramian = engine.compute_gramian(losses)  # shape: [16, 16]
++   weights = weighting(gramian)  # shape: [16]
++   losses.backward(weights)
+    optimizer.step()
+```
+
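
The three new lines at the end of the loop rely on a standard autograd identity: calling
`losses.backward(weights)` backpropagates the weighted sum of the per-sample losses, which is
exactly how the weights extracted from the Gramian are used to combine the losses. A small
self-contained check of this identity in plain PyTorch (arbitrary numbers, no TorchJD involved):

```python
import torch
from torch.nn import Linear

# Tiny setup: 4 per-sample squared errors from a linear model.
model = Linear(3, 1)
inputs = torch.randn(4, 3)
targets = torch.randn(4)
losses = (model(inputs).squeeze(dim=1) - targets) ** 2  # shape: [4]
weights = torch.tensor([0.1, 0.2, 0.3, 0.4])  # stand-in for the output of a Weighting

# Backpropagating the loss vector with `weights` accumulates Jᵀw in the parameters'
# .grad fields, where J is the Jacobian of the losses w.r.t. the parameters.
losses.backward(weights, retain_graph=True)
grad_from_vector = model.weight.grad.clone()

# Backpropagating the scalar weighted sum of the losses gives the same gradients.
model.weight.grad = None
(weights @ losses).backward()
assert torch.allclose(grad_from_vector, model.weight.grad)
```
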
 More usage examples can be found [here](https://torchjd.org/stable/examples/).
 
-## Supported Aggregators
+## Supported Aggregators and Weightings
 TorchJD provides many existing aggregators from the literature, listed in the following table.
 
 <!-- recommended aggregators first, then alphabetical order -->
-| Aggregator | Publication |
-|------------|-------------|
-| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/) (recommended) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
-| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl/) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
-| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad/) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
-| [ConFIG](https://torchjd.org/stable/docs/aggregation/config/) | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
-| [Constant](https://torchjd.org/stable/docs/aggregation/constant/) | - |
-| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj/) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
-| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop/) | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
-| [IMTL-G](https://torchjd.org/stable/docs/aggregation/imtl_g/) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
-| [Krum](https://torchjd.org/stable/docs/aggregation/krum/) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
-| [Mean](https://torchjd.org/stable/docs/aggregation/mean/) | - |
-| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda/) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://www.sciencedirect.com/science/article/pii/S1631073X12000738) |
-| [Nash-MTL](https://torchjd.org/stable/docs/aggregation/nash_mtl/) | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) |
-| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad/) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
-| [Random](https://torchjd.org/stable/docs/aggregation/random/) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) |
-| [Sum](https://torchjd.org/stable/docs/aggregation/sum/) | - |
-| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean/) | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) |
-
-The following example shows how to instantiate
-[UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/) and aggregate a simple matrix `J` with
-it.
-```python
-from torch import tensor
-from torchjd.aggregation import UPGrad
-
-A = UPGrad()
-J = tensor([[-4., 1., 1.], [6., 1., 1.]])
-
-A(J)
-# Output: tensor([0.2929, 1.9004, 1.9004])
-```
-
-> [!TIP]
-> When using TorchJD, you generally don't have to use aggregators directly. You simply instantiate
-> one and pass it to the backward function (`torchjd.backward` or `torchjd.mtl_backward`), which
-> will in turn apply it to the Jacobian matrix that it will compute.
+| Aggregator | Weighting | Publication |
+|------------|-----------|-------------|
+| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad.html#torchjd.aggregation.UPGrad) (recommended) | [UPGradWeighting](https://torchjd.org/stable/docs/aggregation/upgrad#torchjd.aggregation.UPGradWeighting) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
+| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTL) | [AlignedMTLWeighting](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTLWeighting) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
+| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGrad) | [CAGradWeighting](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGradWeighting) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
+| [ConFIG](https://torchjd.org/stable/docs/aggregation/config#torchjd.aggregation.ConFIG) | - | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
+| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
+| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
+| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
+| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
+| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
+| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
+| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDA) | [MGDAWeighting](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDAWeighting) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://www.sciencedirect.com/science/article/pii/S1631073X12000738) |
+| [NashMTL](https://torchjd.org/stable/docs/aggregation/nash_mtl#torchjd.aggregation.NashMTL) | - | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) |
+| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGrad) | [PCGradWeighting](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGradWeighting) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
+| [Random](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.Random) | [RandomWeighting](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.RandomWeighting) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) |
+| [Sum](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.Sum) | [SumWeighting](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.SumWeighting) | - |
+| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean#torchjd.aggregation.TrimmedMean) | - | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) |
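
The `Weighting` entries in the table mirror the autogram example above: a `Weighting` maps the
Gramian of the Jacobian to one weight per loss. A minimal sketch with a hand-made 2×2 Gramian
(arbitrary but symmetric positive semi-definite values), assuming a `Weighting` can be called
directly on such a matrix as in the training loop above:

```python
from torch import tensor
from torchjd.aggregation import UPGradWeighting

# Hypothetical Gramian of a Jacobian with two rows (two losses).
gramian = tensor([[4.0, -1.0], [-1.0, 9.0]])

weighting = UPGradWeighting()
weights = weighting(gramian)  # one weight per loss
```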
 
 ## Contribution
 Please read the [Contribution page](CONTRIBUTING.md).
