Commit bc5e149

Reorder tests in test_rst.py in alphabetical order
1 parent 8102a9c

1 file changed (+74, -74 lines)


tests/doc/test_rst.py

Lines changed: 74 additions & 74 deletions
@@ -5,6 +5,47 @@
 """
 
 
+def test_amp():
+    import torch
+    from torch.amp import GradScaler
+    from torch.nn import Linear, MSELoss, ReLU, Sequential
+    from torch.optim import SGD
+
+    from torchjd.aggregation import UPGrad
+    from torchjd.autojac import mtl_backward
+
+    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
+    task1_module = Linear(3, 1)
+    task2_module = Linear(3, 1)
+    params = [
+        *shared_module.parameters(),
+        *task1_module.parameters(),
+        *task2_module.parameters(),
+    ]
+    scaler = GradScaler(device="cpu")
+    loss_fn = MSELoss()
+    optimizer = SGD(params, lr=0.1)
+    aggregator = UPGrad()
+
+    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
+    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
+    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task
+
+    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
+        with torch.autocast(device_type="cpu", dtype=torch.float16):
+            features = shared_module(input)
+            output1 = task1_module(features)
+            output2 = task2_module(features)
+            loss1 = loss_fn(output1, target1)
+            loss2 = loss_fn(output2, target2)
+
+        scaled_losses = scaler.scale([loss1, loss2])
+        optimizer.zero_grad()
+        mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
+        scaler.step(optimizer)
+        scaler.update()
+
+
 def test_basic_usage():
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -111,43 +152,6 @@ def test_autogram():
     test_autogram()
 
 
-def test_mtl():
-    import torch
-    from torch.nn import Linear, MSELoss, ReLU, Sequential
-    from torch.optim import SGD
-
-    from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
-
-    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
-    task1_module = Linear(3, 1)
-    task2_module = Linear(3, 1)
-    params = [
-        *shared_module.parameters(),
-        *task1_module.parameters(),
-        *task2_module.parameters(),
-    ]
-
-    loss_fn = MSELoss()
-    optimizer = SGD(params, lr=0.1)
-    aggregator = UPGrad()
-
-    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
-    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
-    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task
-
-    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
-        features = shared_module(input)
-        output1 = task1_module(features)
-        output2 = task2_module(features)
-        loss1 = loss_fn(output1, target1)
-        loss2 = loss_fn(output2, target2)
-
-        optimizer.zero_grad()
-        mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
-        optimizer.step()
-
-
 def test_lightning_integration():
     # Extra ----------------------------------------------------------------------------------------
     import logging
@@ -214,30 +218,6 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
     trainer.fit(model=model, train_dataloaders=train_loader)
 
 
-def test_rnn():
-    import torch
-    from torch.nn import RNN
-    from torch.optim import SGD
-
-    from torchjd.aggregation import UPGrad
-    from torchjd.autojac import backward
-
-    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
-    optimizer = SGD(rnn.parameters(), lr=0.1)
-    aggregator = UPGrad()
-
-    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
-    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.
-
-    for input, target in zip(inputs, targets):
-        output, _ = rnn(input)  # output is of shape [5, 3, 20].
-        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
-
-        optimizer.zero_grad()
-        backward(losses, aggregator, parallel_chunk_size=1)
-        optimizer.step()
-
-
 def test_monitoring():
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -290,9 +270,8 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
         optimizer.step()
 
 
-def test_amp():
+def test_mtl():
     import torch
-    from torch.amp import GradScaler
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD
 
@@ -307,7 +286,7 @@ def test_amp():
         *task1_module.parameters(),
         *task2_module.parameters(),
     ]
-    scaler = GradScaler(device="cpu")
+
     loss_fn = MSELoss()
     optimizer = SGD(params, lr=0.1)
     aggregator = UPGrad()
@@ -317,18 +296,15 @@ def test_amp():
     task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task
 
     for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
-        with torch.autocast(device_type="cpu", dtype=torch.float16):
-            features = shared_module(input)
-            output1 = task1_module(features)
-            output2 = task2_module(features)
-            loss1 = loss_fn(output1, target1)
-            loss2 = loss_fn(output2, target2)
+        features = shared_module(input)
+        output1 = task1_module(features)
+        output2 = task2_module(features)
+        loss1 = loss_fn(output1, target1)
+        loss2 = loss_fn(output2, target2)
 
-        scaled_losses = scaler.scale([loss1, loss2])
         optimizer.zero_grad()
-        mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
-        scaler.step(optimizer)
-        scaler.update()
+        mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
+        optimizer.step()
 
 
 def test_partial_jd():
@@ -362,3 +338,27 @@ def test_partial_jd():
         weights = weighting(gramian)
         losses.backward(weights)
         optimizer.step()
+
+
+def test_rnn():
+    import torch
+    from torch.nn import RNN
+    from torch.optim import SGD
+
+    from torchjd.aggregation import UPGrad
+    from torchjd.autojac import backward
+
+    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
+    optimizer = SGD(rnn.parameters(), lr=0.1)
+    aggregator = UPGrad()
+
+    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
+    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.
+
+    for input, target in zip(inputs, targets):
+        output, _ = rnn(input)  # output is of shape [5, 3, 20].
+        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
+
+        optimizer.zero_grad()
+        backward(losses, aggregator, parallel_chunk_size=1)
+        optimizer.step()
