Commit d3a9e8b

test(autogram): Revamp engine tests (#418)
* Split test_equivalence_autojac_autogram into test_compute_gramian and test_iwrm_steps_with_autogram
* Change test_partial_autogram into test_compute_partial_gramian
* Add compute_gramian_with_autograd and autograd_gramian_forward_backward in forward_backwards.py
* Compare to autograd gramian instead of autojac in test_compute_gramian and test_compute_partial_gramian
* Simplify test_autograd_while_modules_are_hooked
* Add garbage collection step in init functions of speed tests
* Add speed test for autograd_gramian_forward_backward
1 parent 5acab1d commit d3a9e8b
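Context for the diffs below: both the plain-autograd reference path and the autogram Engine are compared on the Gramian of the per-sample loss Jacobian (the removed test code built it explicitly as jacobian_matrix @ jacobian_matrix.T). A minimal, self-contained illustration of that quantity in plain PyTorch, independent of torchjd; the tensor names (theta, x) are illustrative only and do not come from the repository:

import torch

# Toy setting: 3 per-sample losses of a linear model with 4 parameters.
theta = torch.randn(4, requires_grad=True)
x = torch.randn(3, 4)
losses = (x @ theta) ** 2  # shape [3], one loss per sample

# Jacobian J of the losses w.r.t. theta, one row per loss.
rows = [torch.autograd.grad(loss, theta, retain_graph=True)[0] for loss in losses]
jacobian = torch.stack(rows)  # shape [3, 4]

gramian = jacobian @ jacobian.T  # shape [3, 3]; this is the quantity the two engines are compared on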

3 files changed (+156, -124 lines). Only two file diffs appear below; the changes to forward_backwards.py mentioned in the commit message are not shown on this page.


tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 25 additions & 2 deletions
@@ -1,3 +1,4 @@
+import gc
 import time
 
 import torch
@@ -15,6 +16,7 @@
 )
 from utils.forward_backwards import (
     autograd_forward_backward,
+    autograd_gramian_forward_backward,
     autogram_forward_backward,
     autojac_forward_backward,
     make_mse_loss_fn,
@@ -31,8 +33,8 @@
     (AlexNet, 8),
     (InstanceNormResNet18, 16),
     (GroupNormMobileNetV3Small, 16),
-    (SqueezeNet, 16),
-    (InstanceNormMobileNetV2, 8),
+    (SqueezeNet, 4),
+    (InstanceNormMobileNetV2, 2),
 ]
 
 
@@ -58,20 +60,31 @@ def fn_autograd():
 
 def init_fn_autograd():
     torch.cuda.empty_cache()
+    gc.collect()
     fn_autograd()
 
+def fn_autograd_gramian():
+    autograd_gramian_forward_backward(model, inputs, list(model.parameters()), loss_fn, W)
+
+def init_fn_autograd_gramian():
+    torch.cuda.empty_cache()
+    gc.collect()
+    fn_autograd_gramian()
+
 def fn_autojac():
     autojac_forward_backward(model, inputs, loss_fn, A)
 
 def init_fn_autojac():
     torch.cuda.empty_cache()
+    gc.collect()
     fn_autojac()
 
 def fn_autogram():
     autogram_forward_backward(model, engine, W, inputs, loss_fn)
 
 def init_fn_autogram():
     torch.cuda.empty_cache()
+    gc.collect()
     fn_autogram()
 
 def optionally_cuda_sync():
@@ -91,6 +104,16 @@ def post_fn():
 print(autograd_times)
 print()
 
+autograd_gramian_times = torch.tensor(
+    time_call(fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs)
+)
+print(
+    f"autograd gramian times (avg = {autograd_gramian_times.mean():.5f}, std = "
+    f"{autograd_gramian_times.std():.5f}"
+)
+print(autograd_gramian_times)
+print()
+
 autojac_times = torch.tensor(time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs))
 print(f"autojac times (avg = {autojac_times.mean():.5f}, std = {autojac_times.std():.5f}")
 print(autojac_times)
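The helper autograd_gramian_forward_backward timed above is added to forward_backwards.py in this commit, but that file's diff is not shown on this page. Below is a hedged sketch of what it plausibly does, inferred from its call signature here and from the losses.backward(weighting(gramian)) pattern in the test code removed further down; the body is an assumption, not the committed implementation:

def autograd_gramian_forward_backward(model, inputs, params, loss_fn, weighting):
    """Plausible sketch: forward pass, Gramian via plain autograd, then weighted backward pass."""
    output = model(inputs)
    losses = loss_fn(output)  # one loss per sample, shape [m]
    # Gramian of the per-sample loss Jacobian w.r.t. `params`; see the
    # compute_gramian_with_autograd sketch after the test_engine.py diff below.
    gramian = compute_gramian_with_autograd(losses, params, retain_graph=True)
    # The weighting (e.g. UPGradWeighting) maps the [m, m] Gramian to m per-loss
    # weights, which are backpropagated as the gradient seed for `losses`.
    losses.backward(weighting(gramian))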

tests/unit/autogram/test_engine.py

Lines changed: 91 additions & 121 deletions
@@ -5,6 +5,7 @@
 from pytest import mark, param
 from torch import nn
 from torch.optim import SGD
+from torch.testing import assert_close
 from unit.conftest import DEVICE
 from utils.architectures import (
     AlexNet,
@@ -52,15 +53,13 @@
 from utils.forward_backwards import (
     autograd_forward_backward,
     autogram_forward_backward,
-    autojac_forward_backward,
+    compute_gramian_with_autograd,
     make_mse_loss_fn,
 )
 from utils.tensors import make_tensors
 
-from torchjd.aggregation import UPGrad, UPGradWeighting
+from torchjd.aggregation import UPGradWeighting
 from torchjd.autogram._engine import Engine
-from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet
-from torchjd.autojac._transform._aggregate import _Matrixify
 
 PARAMETRIZATIONS = [
     (OverlyNested, 32),
@@ -107,110 +106,34 @@
 
 
 @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
-def test_equivalence_autojac_autogram(
-    architecture: type[ShapedModule],
-    batch_size: int,
-):
-    """
-    Tests that the autogram engine gives the same results as the autojac engine on IWRM for several
-    JD steps.
-    """
-
-    n_iter = 3
+def test_compute_gramian(architecture: type[ShapedModule], batch_size: int):
+    """Tests that the autograd and the autogram engines compute the same gramian."""
 
     input_shapes = architecture.INPUT_SHAPES
     output_shapes = architecture.OUTPUT_SHAPES
 
-    weighting = UPGradWeighting()
-    aggregator = UPGrad()
-
     torch.manual_seed(0)
-    model_autojac = architecture().to(device=DEVICE)
+    model_autograd = architecture().to(device=DEVICE)
     torch.manual_seed(0)
     model_autogram = architecture().to(device=DEVICE)
 
     engine = Engine(model_autogram.modules())
-    optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7)
-    optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7)
-
-    for i in range(n_iter):
-        inputs = make_tensors(batch_size, input_shapes)
-        targets = make_tensors(batch_size, output_shapes)
-        loss_fn = make_mse_loss_fn(targets)
-
-        torch.random.manual_seed(0)  # Fix randomness for random aggregators and random models
-        autojac_forward_backward(model_autojac, inputs, loss_fn, aggregator)
-        expected_grads = {
-            name: p.grad for name, p in model_autojac.named_parameters() if p.grad is not None
-        }
-
-        torch.random.manual_seed(0)  # Fix randomness for random weightings and random models
-        autogram_forward_backward(model_autogram, engine, weighting, inputs, loss_fn)
-        grads = {
-            name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None
-        }
-
-        assert_tensor_dicts_are_close(grads, expected_grads)
-
-        optimizer_autojac.step()
-        model_autojac.zero_grad()
-
-        optimizer_autogram.step()
-        model_autogram.zero_grad()
-
-
-@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
-def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], batch_size: int):
-    """
-    Tests that the hooks added when constructing the engine do not interfere with a simple autograd
-    call.
-    """
-
-    input_shapes = architecture.INPUT_SHAPES
-    output_shapes = architecture.OUTPUT_SHAPES
 
-    W = UPGradWeighting()
-    A = UPGrad()
-    input = make_tensors(batch_size, input_shapes)
+    inputs = make_tensors(batch_size, input_shapes)
     targets = make_tensors(batch_size, output_shapes)
     loss_fn = make_mse_loss_fn(targets)
 
-    torch.manual_seed(0)
-    model = architecture().to(device=DEVICE)
-
-    torch.manual_seed(0)  # Fix randomness for random models
-    autojac_forward_backward(model, input, loss_fn, A)
-    autojac_grads = {
-        name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None
-    }
-    model.zero_grad()
-
-    torch.manual_seed(0)  # Fix randomness for random models
-    autograd_forward_backward(model, input, loss_fn)
-    autograd_grads = {
-        name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None
-    }
-
-    torch.manual_seed(0)
-    model_autogram = architecture().to(device=DEVICE)
+    torch.random.manual_seed(0)  # Fix randomness for random models
+    output = model_autograd(inputs)
+    losses = loss_fn(output)
+    autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters()))
 
-    # Hook modules and verify that we're equivalent to autojac when using the engine
-    engine = Engine(model_autogram.modules())
-    torch.manual_seed(0)  # Fix randomness for random models
-    autogram_forward_backward(model_autogram, engine, W, input, loss_fn)
-    grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}
-    assert_tensor_dicts_are_close(grads, autojac_grads)
-    model_autogram.zero_grad()
+    torch.random.manual_seed(0)  # Fix randomness for random models
+    output = model_autogram(inputs)
+    losses = loss_fn(output)
+    autogram_gramian = engine.compute_gramian(losses)
 
-    # Verify that even with the hooked modules, autograd works normally when not using the engine.
-    # Results should be the same as a normal call to autograd, and no time should be spent computing
-    # the gramian at all.
-    torch.manual_seed(0)  # Fix randomness for random models
-    autograd_forward_backward(model_autogram, input, loss_fn)
-    assert engine._gramian_accumulator.gramian is None
-    grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}
-    assert_tensor_dicts_are_close(grads, autograd_grads)
-    model_autogram.zero_grad()
+    assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5)
 
 
 def _non_empty_subsets(elements: set) -> list[set]:
@@ -221,20 +144,15 @@ def _non_empty_subsets(elements: set) -> list[set]:
 
 
 @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"}))
-def test_partial_autogram(gramian_module_names: set[str]):
+def test_compute_partial_gramian(gramian_module_names: set[str]):
     """
-    Tests that partial JD via the autogram engine works similarly as if the gramian was computed via
-    the autojac engine.
-
-    Note that this test is a bit redundant now that we have the Engine interface, because it now
-    just compares two ways of computing the Gramian, which is independant of the idea of partial JD.
+    Tests that the autograd and the autogram engines compute the same gramian when only a subset of
+    the model parameters is specified.
     """
 
     architecture = SimpleBranched
     batch_size = 64
 
-    weighting = UPGradWeighting()
-
     input_shapes = architecture.INPUT_SHAPES
     output_shapes = architecture.OUTPUT_SHAPES
 
@@ -247,39 +165,91 @@ def test_partial_autogram(gramian_module_names: set[str]):
 
     output = model(input)
     losses = loss_fn(output)
-    losses_ = OrderedSet(losses)
-
-    init = Init(losses_)
-    diag = Diagonalize(losses_)
 
     gramian_modules = [model.get_submodule(name) for name in gramian_module_names]
-    gramian_params = OrderedSet({})
+    gramian_params = []
    for m in gramian_modules:
-        gramian_params += OrderedSet(m.parameters())
+        gramian_params += list(m.parameters())
 
-    jac = Jac(losses_, OrderedSet(gramian_params), None, True)
-    mat = _Matrixify()
-    transform = mat << jac << diag << init
-
-    jacobian_matrices = transform({})
-    jacobian_matrix = torch.cat(list(jacobian_matrices.values()), dim=1)
-    gramian = jacobian_matrix @ jacobian_matrix.T
+    autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True)
     torch.manual_seed(0)
-    losses.backward(weighting(gramian))
-
-    expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
-    model.zero_grad()
 
     engine = Engine(gramian_modules)
 
     output = model(input)
     losses = loss_fn(output)
     gramian = engine.compute_gramian(losses)
+
+    assert_close(gramian, autograd_gramian)
+
+
+@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
+def test_iwrm_steps_with_autogram(architecture: type[ShapedModule], batch_size: int):
+    """Tests that the autogram engine doesn't raise any error during several IWRM iterations."""
+
+    n_iter = 3
+
+    input_shapes = architecture.INPUT_SHAPES
+    output_shapes = architecture.OUTPUT_SHAPES
+
+    weighting = UPGradWeighting()
+
+    model = architecture().to(device=DEVICE)
+
+    engine = Engine(model.modules())
+    optimizer = SGD(model.parameters(), lr=1e-7)
+
+    for i in range(n_iter):
+        inputs = make_tensors(batch_size, input_shapes)
+        targets = make_tensors(batch_size, output_shapes)
+        loss_fn = make_mse_loss_fn(targets)
+
+        autogram_forward_backward(model, engine, weighting, inputs, loss_fn)
+
+        optimizer.step()
+        model.zero_grad()
+
+
+@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
+@mark.parametrize("compute_gramian", [False, True])
+def test_autograd_while_modules_are_hooked(
+    architecture: type[ShapedModule], batch_size: int, compute_gramian: bool
+):
+    """
+    Tests that the hooks added when constructing the engine do not interfere with a simple autograd
+    call.
+    """
+
+    input = make_tensors(batch_size, architecture.INPUT_SHAPES)
+    targets = make_tensors(batch_size, architecture.OUTPUT_SHAPES)
+    loss_fn = make_mse_loss_fn(targets)
+
+    torch.manual_seed(0)
+    model = architecture().to(device=DEVICE)
     torch.manual_seed(0)
-    losses.backward(weighting(gramian))
+    model_autogram = architecture().to(device=DEVICE)
+
+    torch.manual_seed(0)  # Fix randomness for random models
+    autograd_forward_backward(model, input, loss_fn)
+    autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
+
+    # Hook modules and optionally compute the Gramian
+    engine = Engine(model_autogram.modules())
+    if compute_gramian:
+        torch.manual_seed(0)  # Fix randomness for random models
+        output = model_autogram(input)
+        losses = loss_fn(output)
+        _ = engine.compute_gramian(losses)
 
-    grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
-    assert_tensor_dicts_are_close(grads, expected_grads)
+    # Verify that even with the hooked modules, autograd works normally when not using the engine.
+    # Results should be the same as a normal call to autograd, and no time should be spent computing
+    # the gramian at all.
+    torch.manual_seed(0)  # Fix randomness for random models
+    autograd_forward_backward(model_autogram, input, loss_fn)
+    grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}
+
+    assert_tensor_dicts_are_close(grads, autograd_grads)
+    assert engine._gramian_accumulator.gramian is None
 
 
 @mark.parametrize("architecture", [WithRNN, WithModuleTrackingRunningStats])
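compute_gramian_with_autograd, which the new tests use as the reference, also lives in forward_backwards.py and is therefore not part of this page's diff. Based on how it is called above (a vector of losses plus a list of parameters, with an optional retain_graph), here is a minimal sketch of one way to implement it with plain torch.autograd; the body, argument defaults, and handling of unused parameters are assumptions, not the committed code:

import torch


def compute_gramian_with_autograd(
    losses: torch.Tensor,  # 1-D tensor of per-sample losses, shape [m]
    params: list[torch.Tensor],  # parameters to differentiate with respect to
    retain_graph: bool = False,
) -> torch.Tensor:
    """Sketch: Gramian J @ J.T of the Jacobian of `losses` w.r.t. `params`."""
    rows = []
    for i, loss in enumerate(losses):
        # Keep the graph alive for all but the last row (or if the caller asks for it).
        keep = retain_graph or i < len(losses) - 1
        grads = torch.autograd.grad(loss, params, retain_graph=keep, allow_unused=True)
        flat = torch.cat(
            [
                g.reshape(-1) if g is not None else torch.zeros(p.numel(), device=loss.device)
                for g, p in zip(grads, params)
            ]
        )
        rows.append(flat)
    jacobian = torch.stack(rows)  # shape [m, n_params]
    return jacobian @ jacobian.T  # shape [m, m]

In this sketch the Gramian costs one backward pass per loss, which is presumably why the commit also adds a speed test comparing autograd_gramian_forward_backward against the autogram and autojac paths.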
