Commit 4e067a5

Add IWMTL example in README.md
1 parent 0d4ed2a commit 4e067a5

File tree

1 file changed: +53 -0 lines changed

README.md

Lines changed: 53 additions & 0 deletions
@@ -155,6 +155,59 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgr
optimizer.step()
```

Lastly, you can even combine the two approaches by considering multiple tasks and each element of
the batch independently. We call that Instance-Wise Multitask Learning (IWMTL).

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import Flattening, UPGradWeighting
from torchjd.autogram import Engine

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)  # shape: [16, 3]
    out1 = task1_module(features).squeeze(1)  # shape: [16]
    out2 = task2_module(features).squeeze(1)  # shape: [16]

    # Compute the matrix of losses: one loss per element of the batch and per task
    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]

    # Compute the gramian (inner products between pairs of gradients of the losses)
    gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]

    # Obtain the weights that lead to no conflict between reweighted gradients
    weights = weighting(gramian)  # shape: [16, 2]

    optimizer.zero_grad()
    # Do the standard backward pass, but weighted using the obtained weights
    losses.backward(weights)
    optimizer.step()
```

Note that here, because the losses are a matrix instead of a simple vector, we compute a
*generalized Gramian* and we extract weights from it using a
[GeneralizedWeighting](https://torchjd.org/docs/aggregation/index.html#torchjd.aggregation.GeneralizedWeighting).
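To make the shapes concrete, here is a small sketch (not part of the README, and using an assumed
index convention) of what such a generalized Gramian contains: for a Jacobian holding the gradient
of each per-instance, per-task loss with respect to the shared parameters, it stores all pairwise
inner products between those gradients.

```python
import torch

# Illustrative sketch only: J[i, a] stands for the gradient of the loss of
# instance i and task a with respect to num_params shared parameters.
# The exact index convention of the generalized Gramian is an assumption here.
batch_size, num_tasks, num_params = 16, 2, 50
J = torch.randn(batch_size, num_tasks, num_params)

# All pairwise inner products between per-instance, per-task gradients.
gramian = torch.einsum("iap,jbp->iabj", J, J)  # shape: [16, 2, 2, 16]
```

Under that reading, `Flattening(UPGradWeighting())` can be thought of as treating this tensor as an
ordinary 32x32 Gramian of the 32 flattened losses and reshaping the resulting weights back to
`[16, 2]`; see the `GeneralizedWeighting` documentation for the actual behavior.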
More usage examples can be found [here](https://torchjd.org/stable/examples/).

## Supported Aggregators and Weightings
