Commit 7d5ec94

test(autogram): Improve code quality (#464)

* Extract rng forking into contexts.py
* Make _forward_pass do rng forking
* Make _forward_pass take a reduction parameter
* Make forward_pass public
* Use forward_pass in test_engine.py; stop reseeding (it's now done by forward_pass)
* Make zipping strict in make_mse_loss_fn
* Stop requiring params in autograd_gramian_forward_backward
* Improve the parameter order of autogram_forward_backward
* Rename some variables
* Factor input and target creation out into make_inputs_and_targets
* Reorder some code

1 parent 0b7c3f6 commit 7d5ec94

File tree: 6 files changed (+95, -127 lines)
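
Only four of the six changed files appear in this excerpt; the new helpers make_inputs_and_targets and forward_pass are defined in the two files not shown (presumably tests/utils/tensors.py and tests/utils/forward_backwards.py). Below is a minimal sketch of what they might look like, reconstructed purely from the call sites and removed lines in the diffs that follow; the module paths, type hints, and the exact extent of the RNG fork are assumptions:

# Hypothetical reconstruction, not the actual code from this commit.
from collections.abc import Callable
from typing import Any

from torch import Tensor, nn
from utils.architectures import get_in_out_shapes  # pre-existing helper (module path assumed)
from utils.contexts import fork_rng  # added by this commit (see tests/utils/contexts.py below)
from utils.tensors import make_tensors  # pre-existing helper


def make_inputs_and_targets(model: nn.Module, batch_size: int) -> tuple[Any, Any]:
    # Folds the removed three-line pattern (get_in_out_shapes + two
    # make_tensors calls) into a single call.
    input_shapes, output_shapes = get_in_out_shapes(model)
    inputs = make_tensors(batch_size, input_shapes)
    targets = make_tensors(batch_size, output_shapes)
    return inputs, targets


def forward_pass(
    model: nn.Module,
    inputs: Any,
    loss_fn: Callable[[Any], Any],
    reduction: Callable[[Any], Tensor],
) -> Tensor:
    # Replaces the repeated "torch.manual_seed(0)  # Fix randomness for random
    # models" + model call + reduction pattern in the tests. Whether the loss
    # computation is also inside the RNG fork is a guess.
    with fork_rng(seed=0):
        output = model(inputs)
    return reduction(loss_fn(output))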

tests/speed/autogram/grad_vs_jac_vs_gram.py (4 additions, 7 deletions)

@@ -14,7 +14,6 @@
     NoFreeParam,
     SqueezeNet,
     WithTransformerLarge,
-    get_in_out_shapes,
 )
 from utils.forward_backwards import (
     autograd_forward_backward,
@@ -23,7 +22,7 @@
     autojac_forward_backward,
     make_mse_loss_fn,
 )
-from utils.tensors import make_tensors
+from utils.tensors import make_inputs_and_targets
 
 from torchjd.aggregation import Mean
 from torchjd.autogram import Engine
@@ -43,9 +42,7 @@
 
 def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int):
     model = factory()
-    input_shapes, output_shapes = get_in_out_shapes(model)
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model, batch_size)
     loss_fn = make_mse_loss_fn(targets)
 
     A = Mean()
@@ -64,7 +61,7 @@ def init_fn_autograd():
         fn_autograd()
 
     def fn_autograd_gramian():
-        autograd_gramian_forward_backward(model, inputs, list(model.parameters()), loss_fn, W)
+        autograd_gramian_forward_backward(model, inputs, loss_fn, W)
 
     def init_fn_autograd_gramian():
         torch.cuda.empty_cache()
@@ -80,7 +77,7 @@ def init_fn_autojac():
         fn_autojac()
 
     def fn_autogram():
-        autogram_forward_backward(model, engine, W, inputs, loss_fn)
+        autogram_forward_backward(model, inputs, loss_fn, engine, W)
 
     def init_fn_autogram():
         torch.cuda.empty_cache()
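
The benchmark above also reflects two of the signature changes from the commit message ("Stop requiring params in autograd_gramian_forward_backward", "Improve the parameter order of autogram_forward_backward"). Both helpers live in tests/utils/forward_backwards.py, which is not shown in this excerpt; judging only from the call sites, their new signatures are presumably along these lines (an inference, not the actual code):

# Inferred from call sites only; real definitions are in tests/utils/forward_backwards.py.
def autograd_gramian_forward_backward(model, inputs, loss_fn, weighting) -> None:
    ...  # no longer takes an explicit params list; presumably uses model.parameters()


def autogram_forward_backward(model, inputs, loss_fn, engine, weighting) -> None:
    ...  # data arguments (model, inputs, loss_fn) now precede the machinery (engine, weighting)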

tests/unit/autogram/test_engine.py (41 additions, 97 deletions)

@@ -61,21 +61,21 @@
     WithSomeFrozenModule,
     WithTransformer,
     WithTransformerLarge,
-    get_in_out_shapes,
 )
 from utils.dict_assertions import assert_tensor_dicts_are_close
 from utils.forward_backwards import (
     autograd_forward_backward,
     autogram_forward_backward,
     compute_gramian,
     compute_gramian_with_autograd,
+    forward_pass,
     make_mse_loss_fn,
     reduce_to_first_tensor,
     reduce_to_matrix,
     reduce_to_scalar,
     reduce_to_vector,
 )
-from utils.tensors import make_tensors, ones_, randn_, zeros_
+from utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_
 
 from torchjd.aggregation import UPGradWeighting
 from torchjd.autogram._engine import Engine
@@ -144,22 +144,14 @@ def _assert_gramian_is_equivalent_to_autograd(
     factory: ModuleFactory, batch_size: int, batch_dim: int | None
 ):
     model_autograd, model_autogram = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model_autograd)
-
     engine = Engine(model_autogram, batch_dim=batch_dim)
-
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)
 
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autograd(inputs)
-    losses = reduce_to_vector(loss_fn(output))
+    losses = forward_pass(model_autograd, inputs, loss_fn, reduce_to_vector)
     autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters()))
 
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autogram(inputs)
-    losses = reduce_to_vector(loss_fn(output))
+    losses = forward_pass(model_autogram, inputs, loss_fn, reduce_to_vector)
    autogram_gramian = engine.compute_gramian(losses)
 
     assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=3e-5)
@@ -255,26 +247,18 @@ def test_compute_gramian_various_output_shapes(
 
     factory = ModuleFactory(Ndim2Output)
     model_autograd, model_autogram = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model_autograd)
-
-    engine = Engine(model_autogram, batch_dim=batch_dim)
-
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)
 
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autograd(inputs)
-    losses = reduction(loss_fn(output))
+    losses = forward_pass(model_autograd, inputs, loss_fn, reduction)
     reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination)
     # Go back to a vector so that compute_gramian_with_autograd works
     loss_vector = reshaped_losses.reshape([-1])
     autograd_gramian = compute_gramian_with_autograd(loss_vector, list(model_autograd.parameters()))
     expected_gramian = reshape_gramian(autograd_gramian, list(reshaped_losses.shape))
 
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autogram(inputs)
-    losses = reduction(loss_fn(output))
+    engine = Engine(model_autogram, batch_dim=batch_dim)
+    losses = forward_pass(model_autogram, inputs, loss_fn, reduction)
     reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination)
     autogram_gramian = engine.compute_gramian(reshaped_losses)
 
@@ -296,30 +280,20 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int
     the model parameters is specified.
     """
 
-    factory = ModuleFactory(SimpleBranched)
-    model = factory()
-    input_shapes, output_shapes = get_in_out_shapes(model)
+    model = SimpleBranched()
     batch_size = 64
-
-    input = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model, batch_size)
     loss_fn = make_mse_loss_fn(targets)
-
-    output = model(input)
-    losses = reduce_to_vector(loss_fn(output))
-
     gramian_modules = [model.get_submodule(name) for name in gramian_module_names]
     gramian_params = []
     for m in gramian_modules:
         gramian_params += list(m.parameters())
 
+    losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
     autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True)
-    torch.manual_seed(0)
 
     engine = Engine(*gramian_modules, batch_dim=batch_dim)
-
-    output = model(input)
-    losses = reduce_to_vector(loss_fn(output))
+    losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
     gramian = engine.compute_gramian(losses)
 
     assert_close(gramian, autograd_gramian)
@@ -331,22 +305,15 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch
     """Tests that the autogram engine doesn't raise any error during several IWRM iterations."""
 
     n_iter = 3
-
     model = factory()
-    input_shapes, output_shapes = get_in_out_shapes(model)
-
     weighting = UPGradWeighting()
-
     engine = Engine(model, batch_dim=batch_dim)
     optimizer = SGD(model.parameters(), lr=1e-7)
 
     for i in range(n_iter):
-        inputs = make_tensors(batch_size, input_shapes)
-        targets = make_tensors(batch_size, output_shapes)
+        inputs, targets = make_inputs_and_targets(model, batch_size)
         loss_fn = make_mse_loss_fn(targets)
-
-        autogram_forward_backward(model, engine, weighting, inputs, loss_fn)
-
+        autogram_forward_backward(model, inputs, loss_fn, engine, weighting)
         optimizer.step()
         model.zero_grad()
 
@@ -363,29 +330,22 @@ def test_autograd_while_modules_are_hooked(
     """
 
     model, model_autogram = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model)
-
-    input = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model, batch_size)
     loss_fn = make_mse_loss_fn(targets)
 
-    torch.manual_seed(0)  # Fix randomness for random models
-    autograd_forward_backward(model, input, loss_fn)
+    autograd_forward_backward(model, inputs, loss_fn)
     autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
 
     # Hook modules and optionally compute the Gramian
     engine = Engine(model_autogram, batch_dim=batch_dim)
     if use_engine:
-        torch.manual_seed(0)  # Fix randomness for random models
-        output = model_autogram(input)
-        losses = reduce_to_vector(loss_fn(output))
+        losses = forward_pass(model_autogram, inputs, loss_fn, reduce_to_vector)
        _ = engine.compute_gramian(losses)
 
     # Verify that even with the hooked modules, autograd works normally when not using the engine.
     # Results should be the same as a normal call to autograd, and no time should be spent computing
     # the gramian at all.
-    torch.manual_seed(0)  # Fix randomness for random models
-    autograd_forward_backward(model_autogram, input, loss_fn)
+    autograd_forward_backward(model_autogram, inputs, loss_fn)
     grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}
 
     assert_tensor_dicts_are_close(grads, autograd_grads)
@@ -416,12 +376,11 @@ def test_compute_gramian_manual():
 
     in_dims = 18
     out_dims = 25
-
     factory = ModuleFactory(Linear, in_dims, out_dims)
     model = factory()
-    engine = Engine(model, batch_dim=None)
-
     input = randn_(in_dims)
+
+    engine = Engine(model, batch_dim=None)
     output = model(input)
     gramian = engine.compute_gramian(output)
 
@@ -462,21 +421,19 @@ def test_reshape_equivariance(shape: list[int]):
 
     input_size = shape[0]
     output_size = prod(shape[1:])
-
     factory = ModuleFactory(Linear, input_size, output_size)
     model1, model2 = factory(), factory()
+    input = randn_([input_size])
 
     engine1 = Engine(model1, batch_dim=None)
-    engine2 = Engine(model2, batch_dim=None)
-
-    input = randn_([input_size])
     output = model1(input)
-    reshaped_output = model2(input).reshape(shape[1:])
-
     gramian = engine1.compute_gramian(output)
-    reshaped_gramian = engine2.compute_gramian(reshaped_output)
     expected_reshaped_gramian = reshape_gramian(gramian, shape[1:])
 
+    engine2 = Engine(model2, batch_dim=None)
+    reshaped_output = model2(input).reshape(shape[1:])
+    reshaped_gramian = engine2.compute_gramian(reshaped_output)
+
     assert_close(reshaped_gramian, expected_reshaped_gramian)
 
 
@@ -502,21 +459,19 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
 
     input_size = shape[0]
     output_size = prod(shape[1:])
-
     factory = ModuleFactory(Linear, input_size, output_size)
     model1, model2 = factory(), factory()
+    input = randn_([input_size])
 
     engine1 = Engine(model1, batch_dim=None)
-    engine2 = Engine(model2, batch_dim=None)
-
-    input = randn_([input_size])
     output = model1(input).reshape(shape[1:])
-    moved_output = model2(input).reshape(shape[1:]).movedim(source, destination)
-
     gramian = engine1.compute_gramian(output)
-    moved_gramian = engine2.compute_gramian(moved_output)
     expected_moved_gramian = movedim_gramian(gramian, source, destination)
 
+    engine2 = Engine(model2, batch_dim=None)
+    moved_output = model2(input).reshape(shape[1:]).movedim(source, destination)
+    moved_gramian = engine2.compute_gramian(moved_output)
+
     assert_close(moved_gramian, expected_moved_gramian)
 
 
@@ -545,18 +500,16 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
     input_size = prod(non_batched_shape)
     batch_size = shape[batch_dim]
     output_size = input_size
-
     factory = ModuleFactory(Linear, input_size, output_size)
     model1, model2 = factory(), factory()
+    input = randn_([batch_size, input_size])
 
     engine1 = Engine(model1, batch_dim=batch_dim)
-    engine2 = Engine(model2, batch_dim=None)
-
-    input = randn_([batch_size, input_size])
     output1 = model1(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
-    output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
-
     gramian1 = engine1.compute_gramian(output1)
+
+    engine2 = Engine(model2, batch_dim=None)
+    output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
     gramian2 = engine2.compute_gramian(output2)
 
     assert_close(gramian1, gramian2)
@@ -573,24 +526,15 @@ def test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: i
     """
 
     model_0, model_none = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model_0)
-
-    engine_0 = Engine(model_0, batch_dim=0)
-    engine_none = Engine(model_none, batch_dim=None)
-
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model_0, batch_size)
     loss_fn = make_mse_loss_fn(targets)
 
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_0(inputs)
-    losses_0 = reduce_to_vector(loss_fn(output))
-
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_none(inputs)
-    losses_none = reduce_to_vector(loss_fn(output))
-
+    engine_0 = Engine(model_0, batch_dim=0)
+    losses_0 = forward_pass(model_0, inputs, loss_fn, reduce_to_vector)
     gramian_0 = engine_0.compute_gramian(losses_0)
+
+    engine_none = Engine(model_none, batch_dim=None)
+    losses_none = forward_pass(model_none, inputs, loss_fn, reduce_to_vector)
    gramian_none = engine_none.compute_gramian(losses_none)
 
     assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5)
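
A recurring reordering in the hunks above (the commit message only says "Reorder some code", so this is a reading, not a stated rationale): each Engine is now constructed right before its first use instead of at the top of the test. Constructing an Engine hooks the model's modules, as the "# Hook modules" comment in test_autograd_while_modules_are_hooked indicates, so late construction keeps the autograd reference computation free of hooks and splits each test into self-contained autograd and autogram blocks.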

tests/utils/architectures.py (2 additions, 3 deletions)

@@ -6,6 +6,7 @@
 from torch import Tensor, nn
 from torch.nn import Flatten, ReLU
 from torch.utils._pytree import PyTree
+from utils.contexts import fork_rng
 
 
 class ModuleFactory:
@@ -15,9 +16,7 @@ def __init__(self, architecture: type[nn.Module], *args, **kwargs):
         self.kwargs = kwargs
 
     def __call__(self) -> nn.Module:
-        devices = [DEVICE] if DEVICE.type == "cuda" else []
-        with torch.random.fork_rng(devices=devices, device_type=DEVICE.type):
-            torch.random.manual_seed(0)
+        with fork_rng(seed=0):
             return self.architecture(*self.args, **self.kwargs).to(device=DEVICE)
 
     def __str__(self) -> str:

tests/utils/contexts.py (14 additions, 2 deletions)

@@ -1,4 +1,16 @@
-from contextlib import AbstractContextManager
-from typing import TypeAlias
+from collections.abc import Generator
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, TypeAlias
+
+import torch
+from device import DEVICE
 
 ExceptionContext: TypeAlias = AbstractContextManager[Exception | None]
+
+
+@contextmanager
+def fork_rng(seed: int = 0) -> Generator[Any, None, None]:
+    devices = [DEVICE] if DEVICE.type == "cuda" else []
+    with torch.random.fork_rng(devices=devices, device_type=DEVICE.type) as ctx:
+        torch.manual_seed(seed)
+        yield ctx
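
For illustration (not part of the commit): torch.random.fork_rng snapshots the global RNG state on entry and restores it on exit, so the manual_seed inside fork_rng only affects the wrapped block. Two forks with the same seed therefore yield identical draws while the surrounding RNG stream stays untouched, which is exactly what ModuleFactory.__call__ relies on to build identically initialized models:

import torch
from utils.contexts import fork_rng

state_before = torch.random.get_rng_state()

with fork_rng(seed=0):
    a = torch.randn(3)
with fork_rng(seed=0):
    b = torch.randn(3)

assert torch.equal(a, b)  # same seed inside the fork, identical draws
assert torch.equal(state_before, torch.random.get_rng_state())  # outer stream unaffected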
