
Commit d3db44f

Merge branch 'main' into linear-gramian-computer
2 parents: 568bea0 + 7d5ec94

6 files changed: +114 additions, -138 deletions


tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 4 additions & 7 deletions
@@ -14,7 +14,6 @@
     NoFreeParam,
     SqueezeNet,
     WithTransformerLarge,
-    get_in_out_shapes,
 )
 from utils.forward_backwards import (
     autograd_forward_backward,
@@ -23,7 +22,7 @@
     autojac_forward_backward,
     make_mse_loss_fn,
 )
-from utils.tensors import make_tensors
+from utils.tensors import make_inputs_and_targets

 from torchjd.aggregation import Mean
 from torchjd.autogram import Engine
@@ -43,9 +42,7 @@

 def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int):
     model = factory()
-    input_shapes, output_shapes = get_in_out_shapes(model)
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model, batch_size)
     loss_fn = make_mse_loss_fn(targets)

     A = Mean()
@@ -64,7 +61,7 @@ def init_fn_autograd():
         fn_autograd()

     def fn_autograd_gramian():
-        autograd_gramian_forward_backward(model, inputs, list(model.parameters()), loss_fn, W)
+        autograd_gramian_forward_backward(model, inputs, loss_fn, W)

     def init_fn_autograd_gramian():
         torch.cuda.empty_cache()
@@ -80,7 +77,7 @@ def init_fn_autojac():
         fn_autojac()

     def fn_autogram():
-        autogram_forward_backward(model, engine, W, inputs, loss_fn)
+        autogram_forward_backward(model, inputs, loss_fn, engine, W)

     def init_fn_autogram():
         torch.cuda.empty_cache()
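Every change in this file is one of two patterns: the get_in_out_shapes / make_tensors pair is collapsed into a single make_inputs_and_targets call, and the forward-backward helpers are normalized to take (model, inputs, loss_fn, ...) first. The body of make_inputs_and_targets is not part of this diff; a minimal sketch of what it presumably does, reusing the two pre-refactor helpers it replaces:

    # Hypothetical sketch (the real helper lives in utils/tensors.py and is not
    # shown in this diff): fuse shape inspection and tensor creation into one call.
    def make_inputs_and_targets(model, batch_size):
        input_shapes, output_shapes = get_in_out_shapes(model)
        inputs = make_tensors(batch_size, input_shapes)
        targets = make_tensors(batch_size, output_shapes)
        return inputs, targets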

tests/unit/autogram/test_engine.py

Lines changed: 45 additions & 103 deletions
@@ -6,7 +6,7 @@
 import torch
 from pytest import mark, param
 from torch import Tensor
-from torch.nn import RNN, BatchNorm2d, InstanceNorm2d, Linear
+from torch.nn import BatchNorm2d, InstanceNorm2d, Linear
 from torch.optim import SGD
 from torch.testing import assert_close
 from utils.architectures import (
@@ -56,25 +56,26 @@
     WithModuleWithStringOutput,
     WithMultiHeadAttention,
     WithNoTensorOutput,
+    WithRNN,
     WithSideEffect,
     WithSomeFrozenModule,
     WithTransformer,
     WithTransformerLarge,
-    get_in_out_shapes,
 )
 from utils.dict_assertions import assert_tensor_dicts_are_close
 from utils.forward_backwards import (
     autograd_forward_backward,
     autogram_forward_backward,
     compute_gramian,
     compute_gramian_with_autograd,
+    forward_pass,
     make_mse_loss_fn,
     reduce_to_first_tensor,
     reduce_to_matrix,
     reduce_to_scalar,
     reduce_to_vector,
 )
-from utils.tensors import make_tensors, ones_, randn_, zeros_
+from utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_

 from torchjd.aggregation import UPGradWeighting
 from torchjd.autogram._engine import Engine
@@ -143,22 +144,14 @@ def _assert_gramian_is_equivalent_to_autograd(
     factory: ModuleFactory, batch_size: int, batch_dim: int | None
 ):
     model_autograd, model_autogram = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model_autograd)
-
     engine = Engine(model_autogram, batch_dim=batch_dim)
-
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)

-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autograd(inputs)
-    losses = reduce_to_vector(loss_fn(output))
+    losses = forward_pass(model_autograd, inputs, loss_fn, reduce_to_vector)
     autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters()))

-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autogram(inputs)
-    losses = reduce_to_vector(loss_fn(output))
+    losses = forward_pass(model_autogram, inputs, loss_fn, reduce_to_vector)
     autogram_gramian = engine.compute_gramian(losses)

     assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=3e-5)
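The removed lines spell out exactly what the new forward_pass helper replaces at each call site: a seed reset, a forward call, and a loss reduction. Its body is not shown in this diff; a plausible sketch, assuming it simply absorbs those three steps:

    # Hypothetical sketch of forward_pass (utils/forward_backwards.py, not shown
    # in this diff). The seed reset mirrors the removed inline code above.
    def forward_pass(model, inputs, loss_fn, reduction):
        torch.random.manual_seed(0)  # fix randomness for random models
        output = model(inputs)
        return reduction(loss_fn(output))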
@@ -179,10 +172,7 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int
         ModuleFactory(WithSideEffect),
         ModuleFactory(Randomness),
         ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True),
-        param(
-            ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True),
-            marks=mark.xfail_if_cuda,
-        ),
+        param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda),
     ],
 )
 @mark.parametrize("batch_size", [1, 3, 32])
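The raw ModuleFactory(RNN, ...) entry is replaced by a WithRNN architecture imported from utils.architectures. Its definition is not in this diff; a plausible sketch, assuming it wraps nn.RNN with the sizes the removed entry used and unpacks the (output, h_n) tuple that nn.RNN returns:

    from torch import nn

    # Hypothetical sketch of the WithRNN test architecture (utils/architectures.py).
    class WithRNN(nn.Module):
        def __init__(self):
            super().__init__()
            # Sizes taken from the removed ModuleFactory(RNN, ...) entry.
            self.rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True)

        def forward(self, x):
            output, _ = self.rnn(x)  # nn.RNN returns (output, h_n)
            return output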
@@ -257,26 +247,18 @@ def test_compute_gramian_various_output_shapes(

     factory = ModuleFactory(Ndim2Output)
     model_autograd, model_autogram = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model_autograd)
-
-    engine = Engine(model_autogram, batch_dim=batch_dim)
-
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)

-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autograd(inputs)
-    losses = reduction(loss_fn(output))
+    losses = forward_pass(model_autograd, inputs, loss_fn, reduction)
     reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination)
     # Go back to a vector so that compute_gramian_with_autograd works
     loss_vector = reshaped_losses.reshape([-1])
     autograd_gramian = compute_gramian_with_autograd(loss_vector, list(model_autograd.parameters()))
     expected_gramian = reshape_gramian(autograd_gramian, list(reshaped_losses.shape))

-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_autogram(inputs)
-    losses = reduction(loss_fn(output))
+    engine = Engine(model_autogram, batch_dim=batch_dim)
+    losses = forward_pass(model_autogram, inputs, loss_fn, reduction)
     reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination)
     autogram_gramian = engine.compute_gramian(reshaped_losses)

@@ -298,30 +280,20 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int
     the model parameters is specified.
     """

-    factory = ModuleFactory(SimpleBranched)
-    model = factory()
-    input_shapes, output_shapes = get_in_out_shapes(model)
+    model = SimpleBranched()
     batch_size = 64
-
-    input = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model, batch_size)
     loss_fn = make_mse_loss_fn(targets)
-
-    output = model(input)
-    losses = reduce_to_vector(loss_fn(output))
-
     gramian_modules = [model.get_submodule(name) for name in gramian_module_names]
     gramian_params = []
     for m in gramian_modules:
         gramian_params += list(m.parameters())

+    losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
     autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True)
-    torch.manual_seed(0)

     engine = Engine(*gramian_modules, batch_dim=batch_dim)
-
-    output = model(input)
-    losses = reduce_to_vector(loss_fn(output))
+    losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
     gramian = engine.compute_gramian(losses)

     assert_close(gramian, autograd_gramian)
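compute_gramian_with_autograd is the baseline that every engine result in this file is checked against. Its implementation is not part of this diff; a hedged sketch of the obvious reference computation, building the Jacobian of the loss vector row by row with plain autograd and forming G = J Jᵀ:

    import torch

    # Hypothetical reference implementation, matching the call sites above.
    def compute_gramian_with_autograd(losses, params, retain_graph=False):
        n = losses.numel()
        rows = []
        for i in range(n):
            # Keep the graph alive between rows; honor retain_graph on the last row.
            keep = retain_graph or i < n - 1
            grads = torch.autograd.grad(losses[i], params, retain_graph=keep)
            rows.append(torch.cat([g.reshape(-1) for g in grads]))
        jacobian = torch.stack(rows)  # [n_losses, n_params]
        return jacobian @ jacobian.T  # Gramian G = J Jᵀ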
@@ -333,22 +305,15 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch
     """Tests that the autogram engine doesn't raise any error during several IWRM iterations."""

     n_iter = 3
-
     model = factory()
-    input_shapes, output_shapes = get_in_out_shapes(model)
-
     weighting = UPGradWeighting()
-
     engine = Engine(model, batch_dim=batch_dim)
     optimizer = SGD(model.parameters(), lr=1e-7)

     for i in range(n_iter):
-        inputs = make_tensors(batch_size, input_shapes)
-        targets = make_tensors(batch_size, output_shapes)
+        inputs, targets = make_inputs_and_targets(model, batch_size)
         loss_fn = make_mse_loss_fn(targets)
-
-        autogram_forward_backward(model, engine, weighting, inputs, loss_fn)
-
+        autogram_forward_backward(model, inputs, loss_fn, engine, weighting)
         optimizer.step()
         model.zero_grad()
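The reordered autogram_forward_backward signature matches the change in the speed test above: (model, inputs, loss_fn, engine, weighting). The helper's body is not shown here; a minimal sketch, assuming it follows the usual autogram pattern of weighting the Gramian and backpropagating the weighted losses:

    # Hypothetical sketch of the reordered helper (utils/forward_backwards.py).
    def autogram_forward_backward(model, inputs, loss_fn, engine, weighting):
        losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
        gramian = engine.compute_gramian(losses)
        weights = weighting(gramian)  # e.g. UPGradWeighting
        losses.backward(weights)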

@@ -365,29 +330,22 @@ def test_autograd_while_modules_are_hooked(
     """

     model, model_autogram = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model)
-
-    input = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model, batch_size)
     loss_fn = make_mse_loss_fn(targets)

-    torch.manual_seed(0)  # Fix randomness for random models
-    autograd_forward_backward(model, input, loss_fn)
+    autograd_forward_backward(model, inputs, loss_fn)
     autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}

     # Hook modules and optionally compute the Gramian
     engine = Engine(model_autogram, batch_dim=batch_dim)
     if use_engine:
-        torch.manual_seed(0)  # Fix randomness for random models
-        output = model_autogram(input)
-        losses = reduce_to_vector(loss_fn(output))
+        losses = forward_pass(model_autogram, inputs, loss_fn, reduce_to_vector)
         _ = engine.compute_gramian(losses)

     # Verify that even with the hooked modules, autograd works normally when not using the engine.
     # Results should be the same as a normal call to autograd, and no time should be spent computing
     # the gramian at all.
-    torch.manual_seed(0)  # Fix randomness for random models
-    autograd_forward_backward(model_autogram, input, loss_fn)
+    autograd_forward_backward(model_autogram, inputs, loss_fn)
     grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}

     assert_tensor_dicts_are_close(grads, autograd_grads)
@@ -398,7 +356,7 @@ def test_autograd_while_modules_are_hooked(
     ["factory", "batch_dim"],
     [
         (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 0),
-        (ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), 0),
+        param(ModuleFactory(WithRNN), 0),
         (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0),
     ],
 )
@@ -418,12 +376,11 @@ def test_compute_gramian_manual():

     in_dims = 18
     out_dims = 25
-
     factory = ModuleFactory(Linear, in_dims, out_dims)
     model = factory()
-    engine = Engine(model, batch_dim=None)
-
     input = randn_(in_dims)
+
+    engine = Engine(model, batch_dim=None)
     output = model(input)
     gramian = engine.compute_gramian(output)

@@ -464,21 +421,19 @@ def test_reshape_equivariance(shape: list[int]):

     input_size = shape[0]
     output_size = prod(shape[1:])
-
     factory = ModuleFactory(Linear, input_size, output_size)
     model1, model2 = factory(), factory()
+    input = randn_([input_size])

     engine1 = Engine(model1, batch_dim=None)
-    engine2 = Engine(model2, batch_dim=None)
-
-    input = randn_([input_size])
     output = model1(input)
-    reshaped_output = model2(input).reshape(shape[1:])
-
     gramian = engine1.compute_gramian(output)
-    reshaped_gramian = engine2.compute_gramian(reshaped_output)
     expected_reshaped_gramian = reshape_gramian(gramian, shape[1:])

+    engine2 = Engine(model2, batch_dim=None)
+    reshaped_output = model2(input).reshape(shape[1:])
+    reshaped_gramian = engine2.compute_gramian(reshaped_output)
+
     assert_close(reshaped_gramian, expected_reshaped_gramian)

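The expected value is produced by reshape_gramian, defined elsewhere in the test utilities. Given a Gramian of shape [n, n] for a flattened output, reshaping the output to shape s should turn the Gramian into shape s + s; a hedged sketch consistent with its use here:

    # Hypothetical sketch: pair each output index with itself, [n, n] -> s + s.
    def reshape_gramian(gramian, shape):
        return gramian.reshape(shape + shape)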

@@ -504,21 +459,19 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:

     input_size = shape[0]
     output_size = prod(shape[1:])
-
     factory = ModuleFactory(Linear, input_size, output_size)
     model1, model2 = factory(), factory()
+    input = randn_([input_size])

     engine1 = Engine(model1, batch_dim=None)
-    engine2 = Engine(model2, batch_dim=None)
-
-    input = randn_([input_size])
     output = model1(input).reshape(shape[1:])
-    moved_output = model2(input).reshape(shape[1:]).movedim(source, destination)
-
     gramian = engine1.compute_gramian(output)
-    moved_gramian = engine2.compute_gramian(moved_output)
     expected_moved_gramian = movedim_gramian(gramian, source, destination)

+    engine2 = Engine(model2, batch_dim=None)
+    moved_output = model2(input).reshape(shape[1:]).movedim(source, destination)
+    moved_gramian = engine2.compute_gramian(moved_output)
+
     assert_close(moved_gramian, expected_moved_gramian)

@@ -547,18 +500,16 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
     input_size = prod(non_batched_shape)
     batch_size = shape[batch_dim]
     output_size = input_size
-
     factory = ModuleFactory(Linear, input_size, output_size)
     model1, model2 = factory(), factory()
+    input = randn_([batch_size, input_size])

     engine1 = Engine(model1, batch_dim=batch_dim)
-    engine2 = Engine(model2, batch_dim=None)
-
-    input = randn_([batch_size, input_size])
     output1 = model1(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
-    output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
-
     gramian1 = engine1.compute_gramian(output1)
+
+    engine2 = Engine(model2, batch_dim=None)
+    output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
     gramian2 = engine2.compute_gramian(output2)

     assert_close(gramian1, gramian2)
@@ -575,24 +526,15 @@ def test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: i
     """

     model_0, model_none = factory(), factory()
-    input_shapes, output_shapes = get_in_out_shapes(model_0)
-
-    engine_0 = Engine(model_0, batch_dim=0)
-    engine_none = Engine(model_none, batch_dim=None)
-
-    inputs = make_tensors(batch_size, input_shapes)
-    targets = make_tensors(batch_size, output_shapes)
+    inputs, targets = make_inputs_and_targets(model_0, batch_size)
     loss_fn = make_mse_loss_fn(targets)

-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_0(inputs)
-    losses_0 = reduce_to_vector(loss_fn(output))
-
-    torch.random.manual_seed(0)  # Fix randomness for random models
-    output = model_none(inputs)
-    losses_none = reduce_to_vector(loss_fn(output))
-
+    engine_0 = Engine(model_0, batch_dim=0)
+    losses_0 = forward_pass(model_0, inputs, loss_fn, reduce_to_vector)
     gramian_0 = engine_0.compute_gramian(losses_0)
+
+    engine_none = Engine(model_none, batch_dim=None)
+    losses_none = forward_pass(model_none, inputs, loss_fn, reduce_to_vector)
     gramian_none = engine_none.compute_gramian(losses_none)

     assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5)
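Taken together, the refactored tests in this file now share one compact setup, composed entirely of calls that appear in this diff:

    # The common pattern after this refactor (names as used in the tests above):
    inputs, targets = make_inputs_and_targets(model, batch_size)
    loss_fn = make_mse_loss_fn(targets)

    engine = Engine(model, batch_dim=0)  # hooks the model's modules
    losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
    gramian = engine.compute_gramian(losses)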
