55"""
66
77
+def test_amp():
+    import torch
+    from torch.amp import GradScaler
+    from torch.nn import Linear, MSELoss, ReLU, Sequential
+    from torch.optim import SGD
+
+    from torchjd.aggregation import UPGrad
+    from torchjd.autojac import mtl_backward
+
+    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
+    task1_module = Linear(3, 1)
+    task2_module = Linear(3, 1)
+    params = [
+        *shared_module.parameters(),
+        *task1_module.parameters(),
+        *task2_module.parameters(),
+    ]
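+    # A single GradScaler scales both losses so their float16 gradients do not underflow;
+    # the same scale factor is applied to each task.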
+    scaler = GradScaler(device="cpu")
+    loss_fn = MSELoss()
+    optimizer = SGD(params, lr=0.1)
+    aggregator = UPGrad()
+
+    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
+    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
+    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task
+
+    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
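+        # The forward pass runs under autocast, so eligible ops are executed in float16.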
+        with torch.autocast(device_type="cpu", dtype=torch.float16):
+            features = shared_module(input)
+            output1 = task1_module(features)
+            output2 = task2_module(features)
+            loss1 = loss_fn(output1, target1)
+            loss2 = loss_fn(output2, target2)
+
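+        # GradScaler.scale also accepts an iterable of tensors; it returns the scaled losses.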
+        scaled_losses = scaler.scale([loss1, loss2])
+        optimizer.zero_grad()
+        mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
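+        # step() unscales the gradients and skips the update if they contain infs/NaNs;
+        # update() then adjusts the scale factor for the next iteration.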
+        scaler.step(optimizer)
+        scaler.update()
+
+
 def test_basic_usage():
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -111,43 +152,6 @@ def test_autogram():
 test_autogram()


-def test_mtl():
-    import torch
-    from torch.nn import Linear, MSELoss, ReLU, Sequential
-    from torch.optim import SGD
-
-    from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
-
-    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
-    task1_module = Linear(3, 1)
-    task2_module = Linear(3, 1)
-    params = [
-        *shared_module.parameters(),
-        *task1_module.parameters(),
-        *task2_module.parameters(),
-    ]
-
-    loss_fn = MSELoss()
-    optimizer = SGD(params, lr=0.1)
-    aggregator = UPGrad()
-
-    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
-    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
-    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task
-
-    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
-        features = shared_module(input)
-        output1 = task1_module(features)
-        output2 = task2_module(features)
-        loss1 = loss_fn(output1, target1)
-        loss2 = loss_fn(output2, target2)
-
-        optimizer.zero_grad()
-        mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
-        optimizer.step()
-
-
 def test_lightning_integration():
     # Extra ----------------------------------------------------------------------------------------
     import logging
@@ -214,30 +218,6 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
     trainer.fit(model=model, train_dataloaders=train_loader)


-def test_rnn():
-    import torch
-    from torch.nn import RNN
-    from torch.optim import SGD
-
-    from torchjd.aggregation import UPGrad
-    from torchjd.autojac import backward
-
-    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
-    optimizer = SGD(rnn.parameters(), lr=0.1)
-    aggregator = UPGrad()
-
-    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
-    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.
-
-    for input, target in zip(inputs, targets):
-        output, _ = rnn(input)  # output is of shape [5, 3, 20].
-        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
-
-        optimizer.zero_grad()
-        backward(losses, aggregator, parallel_chunk_size=1)
-        optimizer.step()
-
-
 def test_monitoring():
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -290,9 +270,8 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
         optimizer.step()


-def test_amp():
+def test_mtl():
     import torch
-    from torch.amp import GradScaler
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD

@@ -307,7 +286,7 @@ def test_amp():
         *task1_module.parameters(),
         *task2_module.parameters(),
     ]
-    scaler = GradScaler(device="cpu")
+
     loss_fn = MSELoss()
     optimizer = SGD(params, lr=0.1)
     aggregator = UPGrad()
@@ -317,18 +296,15 @@ def test_amp():
     task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

     for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
-        with torch.autocast(device_type="cpu", dtype=torch.float16):
-            features = shared_module(input)
-            output1 = task1_module(features)
-            output2 = task2_module(features)
-            loss1 = loss_fn(output1, target1)
-            loss2 = loss_fn(output2, target2)
+        features = shared_module(input)
+        output1 = task1_module(features)
+        output2 = task2_module(features)
+        loss1 = loss_fn(output1, target1)
+        loss2 = loss_fn(output2, target2)

-        scaled_losses = scaler.scale([loss1, loss2])
         optimizer.zero_grad()
-        mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
-        scaler.step(optimizer)
-        scaler.update()
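+        # mtl_backward fills each task head's .grad with the gradient of its own loss, and the
+        # shared parameters' .grad with the aggregation (UPGrad) of the Jacobian of both losses.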
+        mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
+        optimizer.step()


 def test_partial_jd():
@@ -362,3 +338,27 @@ def test_partial_jd():
         weights = weighting(gramian)
         losses.backward(weights)
         optimizer.step()
+
+
+def test_rnn():
+    import torch
+    from torch.nn import RNN
+    from torch.optim import SGD
+
+    from torchjd.aggregation import UPGrad
+    from torchjd.autojac import backward
+
+    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
+    optimizer = SGD(rnn.parameters(), lr=0.1)
+    aggregator = UPGrad()
+
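+    # With batch_first=False (the RNN default), inputs have shape (seq_len, batch, input_size).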
+    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
+    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.
+
+    for input, target in zip(inputs, targets):
+        output, _ = rnn(input)  # output is of shape [5, 3, 20].
+        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
+
+        optimizer.zero_grad()
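+        # parallel_chunk_size=1 differentiates the losses one at a time rather than in parallel,
+        # trading speed for a lower peak memory footprint.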
+        backward(losses, aggregator, parallel_chunk_size=1)
+        optimizer.step()