    optimizer.step()
```

Lastly, you can even combine the two approaches by considering multiple tasks and each element of
the batch independently. We call that Instance-Wise Multitask Learning (IWMTL).

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import Flattening, UPGradWeighting
from torchjd.autogram import Engine

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)  # shape: [16, 3]
    out1 = task1_module(features).squeeze(1)  # shape: [16]
    out2 = task2_module(features).squeeze(1)  # shape: [16]

    # Compute the matrix of losses: one loss per element of the batch and per task
    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]

    # Compute the gramian (inner products between pairs of gradients of the losses)
    gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]

    # Obtain the weights that lead to no conflict between reweighted gradients
    weights = weighting(gramian)  # shape: [16, 2]

    optimizer.zero_grad()
    # Do the standard backward pass, but weighted using the obtained weights
    losses.backward(weights)
    optimizer.step()
```

Note that here, because the losses form a matrix instead of a simple vector, we compute a
*generalized Gramian* and extract weights from it using a
[GeneralizedWeighting](https://torchjd.org/docs/aggregation/index.html#torchjd.aggregation.GeneralizedWeighting).

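For intuition, here is a small self-contained sketch of what such a generalized Gramian contains and
how it relates to an ordinary Gramian. The toy Jacobian and the `[batch, task, task, batch]` index
ordering are assumptions made for illustration, not a description of the library's internals; in the
example above, `Flattening(UPGradWeighting())` consumes this object and returns weights with the same
`[batch, task]` shape as the loss matrix, which is what `losses.backward(weights)` expects.

```python
import torch

# Toy dimensions: 4 batch elements, 2 tasks, 5 (flattened) shared parameters.
B, T, P = 4, 2, 5

# jacobian[i, k] stands for the gradient of the loss of batch element i and task k
# with respect to the shared parameters.
jacobian = torch.randn(B, T, P)

# Generalized Gramian: entry [i, k, l, j] is the inner product between the gradients
# of losses (i, k) and (j, l). Shape: [B, T, T, B].
gen_gramian = torch.einsum("ikp,jlp->iklj", jacobian, jacobian)

# Grouping the (batch, task) axes together recovers an ordinary [B*T, B*T] Gramian,
# i.e. the Gramian of the flattened Jacobian.
flat_jacobian = jacobian.reshape(B * T, P)
flat_gramian = gen_gramian.permute(0, 1, 3, 2).reshape(B * T, B * T)
assert torch.allclose(flat_gramian, flat_jacobian @ flat_jacobian.T, atol=1e-5)
```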
More usage examples can be found [here](https://torchjd.org/stable/examples/).

## Supported Aggregators and Weightings