Instance-Wise Multi-Task Learning (IWMTL)
=========================================

When training a model on multiple tasks, the gradients of the individual tasks are likely to
conflict. This is particularly true when looking at the individual (per-sample) gradients.
The :doc:`autogram engine <../docs/autogram/engine>` can be used to efficiently compute the Gramian
of the Jacobian of the matrix of per-sample and per-task losses. Weights can then be extracted from
this Gramian to reweight the gradients and resolve the conflict entirely.

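
As a rough sketch of the underlying idea, write :math:`L \in \mathbb{R}^{b \times t}` for the
matrix of losses of a batch of :math:`b` samples on :math:`t` tasks, and
:math:`J \in \mathbb{R}^{bt \times n}` for the Jacobian of its flattened entries with respect to
the :math:`n` shared parameters (these symbols are introduced here only for illustration). The
Gramian is then

.. math::
    G = J J^\top \in \mathbb{R}^{bt \times bt},

where each entry is the inner product between two of the individual gradients; a negative entry
signals a conflict. A weighting such as UPGrad uses :math:`G` to choose weights
:math:`w \in \mathbb{R}^{bt}` such that the combined update direction :math:`J^\top w` has a
non-negative inner product with every individual gradient.
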

The following example shows how to do that.

.. code-block:: python
    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import Flattening, UPGradWeighting
    from torchjd.autogram import Engine

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)
    params = [
        *shared_module.parameters(),
        *task1_module.parameters(),
        *task2_module.parameters(),
    ]

    optimizer = SGD(params, lr=0.1)
    mse = MSELoss(reduction="none")
    weighting = Flattening(UPGradWeighting())
    engine = Engine(shared_module.modules(), batch_dim=0)

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task

    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
        features = shared_module(input)  # shape: [16, 3]
        out1 = task1_module(features).squeeze(1)  # shape: [16]
        out2 = task2_module(features).squeeze(1)  # shape: [16]

        # Compute the matrix of losses: one loss per element of the batch and per task
        losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]

        # Compute the gramian (inner products between pairs of gradients of the losses)
        gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]

        # Obtain the weights that lead to no conflict between reweighted gradients
        weights = weighting(gramian)  # shape: [16, 2]

        optimizer.zero_grad()
        # Do the standard backward pass, but weighted using the obtained weights
        losses.backward(weights)
        optimizer.step()

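
Because each entry of the Gramian is an inner product between two per-sample, per-task gradients,
it can also be inspected directly, for instance to measure how much conflict there is in a batch
before reweighting. The sketch below does this with a hypothetical helper,
``count_conflicting_pairs``, which is not part of torchjd.

.. code-block:: python

    import torch

    def count_conflicting_pairs(gramian: torch.Tensor) -> int:
        """Count pairs of individual gradients whose inner product is negative.

        ``gramian`` is expected to have the [b, t, t, b] shape returned by
        ``engine.compute_gramian`` for a [b, t] matrix of losses. Diagonal entries are
        squared norms (never negative) and, by symmetry, each conflicting pair of
        gradients shows up twice, so halving the count of negative entries gives the
        number of conflicting pairs.
        """
        return int((gramian < 0.0).sum().item()) // 2

    # Tiny hand-made Gramian: one sample, two tasks, gradient inner product -0.5.
    # In the training loop above, the same helper could be applied to ``gramian``.
    g = torch.tensor([[1.0, -0.5], [-0.5, 2.0]]).reshape(1, 2, 2, 1)
    print(count_conflicting_pairs(g))  # 1
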

.. note::
    In this example, the tensor of losses is a matrix rather than a vector. The Gramian is thus a
    4D tensor rather than a matrix, and a
    :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting`, such as
    :class:`~torchjd.aggregation._flattening.Flattening`, has to be used to extract a matrix of
    weights from it. More information about ``GeneralizedWeighting`` can be found on the
    :doc:`../../docs/aggregation/index` page.

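
To build intuition for what :class:`~torchjd.aggregation._flattening.Flattening` provides, one can
think of it as flattening the generalized Gramian into an ordinary square Gramian, applying the
wrapped weighting to that matrix, and reshaping the resulting vector of weights back into a matrix.
The sketch below only illustrates this picture: it assumes that the last two dimensions of the 4D
Gramian mirror the first two in reverse order, the helper ``flatten_then_weight`` is hypothetical,
and none of this is torchjd's actual implementation.

.. code-block:: python

    from typing import Callable

    import torch

    def flatten_then_weight(
        gramian: torch.Tensor,
        matrix_weighting: Callable[[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """Illustrative reduction of a [b, t, t, b] Gramian to a standard weighting problem.

        ``matrix_weighting`` stands for anything that maps a square [m, m] Gramian to a
        vector of m weights.
        """
        b, t = gramian.shape[0], gramian.shape[1]
        # Reorder the axes so that rows and columns both index (sample, task) pairs,
        # then collapse them into a [b * t, b * t] matrix.
        square = gramian.permute(0, 1, 3, 2).reshape(b * t, b * t)
        weights = matrix_weighting(square)  # shape: [b * t]
        return weights.reshape(b, t)  # one weight per sample and per task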