@@ -6,9 +6,10 @@
 import torch
 from pytest import mark, param
 from torch import Tensor
-from torch.nn import BatchNorm2d, InstanceNorm2d, Linear
+from torch.nn import BatchNorm2d, InstanceNorm2d, Linear, Module, Parameter
 from torch.optim import SGD
 from torch.testing import assert_close
+from torch.utils._pytree import PyTree
 from utils.architectures import (
     AlexNet,
     Cifar10Model,
@@ -64,6 +65,7 @@
 )
 from utils.dict_assertions import assert_tensor_dicts_are_close
 from utils.forward_backwards import (
+    CloneParams,
 autograd_forward_backward,
 autogram_forward_backward,
 compute_gramian,
@@ -148,15 +150,42 @@ def _assert_gramian_is_equivalent_to_autograd( |
     inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)

-    losses = forward_pass(model_autograd, inputs, loss_fn, reduce_to_vector)
-    autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters()))
+    losses, params = _get_losses_and_params(model_autograd, inputs, loss_fn, reduce_to_vector)
+    autograd_gramian = compute_gramian_with_autograd(losses, params)

     losses = forward_pass(model_autogram, inputs, loss_fn, reduce_to_vector)
     autogram_gramian = engine.compute_gramian(losses)

     assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=3e-5)


+def _get_losses_and_params_with_cross_terms(
+    model: Module,
+    inputs: PyTree,
+    loss_fn: Callable[[PyTree], list[Tensor]],
+    reduction: Callable[[list[Tensor]], Tensor],
+) -> tuple[Tensor, list[Parameter]]:
+    losses = forward_pass(model, inputs, loss_fn, reduction)
+    params = list(model.parameters())
+    return losses, params
+
+
+def _get_losses_and_params_without_cross_terms(
+    model: Module,
+    inputs: PyTree,
+    loss_fn: Callable[[PyTree], list[Tensor]],
+    reduction: Callable[[list[Tensor]], Tensor],
+) -> tuple[Tensor, list[Parameter]]:
+    # Not considering cross-terms (except intra-module parameter reuse):
+    with CloneParams(model) as params:
+        losses = forward_pass(model, inputs, loss_fn, reduction)
+
+    return losses, params
+
+
+_get_losses_and_params = _get_losses_and_params_with_cross_terms
+
+
 @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS)
 @mark.parametrize("batch_dim", [0, None])
 def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int | None):
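
Note on the hunk above, for readers unfamiliar with CloneParams: the intent is that each submodule receives its own clone of the parameters it directly owns, so a parameter reused across modules becomes several independent autograd leaves and its gradient no longer sums across use sites. A minimal sketch of such a context manager follows; it is written for this review only, and the actual CloneParams in utils.forward_backwards may be implemented differently.

# Hypothetical sketch only, NOT the real utils.forward_backwards.CloneParams.
# Cloning happens per module, so reuse *within* one module still shares a
# single clone (the "intra-module parameter reuse" exception in the comment).
from contextlib import contextmanager
from typing import Iterator

from torch import nn


@contextmanager
def clone_params(model: nn.Module) -> Iterator[list[nn.Parameter]]:
    originals: list[tuple[nn.Module, str, nn.Parameter]] = []
    clones: list[nn.Parameter] = []
    try:
        for module in model.modules():
            # recurse=False visits each parameter once per owning module, so a
            # parameter shared by two modules gets two independent clones.
            for name, param in list(module.named_parameters(recurse=False)):
                clone = nn.Parameter(param.detach().clone())
                originals.append((module, name, param))
                clones.append(clone)
                setattr(module, name, clone)  # re-registers under the same name
        yield clones
    finally:
        # Restore the original (possibly shared) parameters on exit.
        for module, name, param in originals:
            setattr(module, name, param)

With that reading, switching the _get_losses_and_params alias to the without_cross_terms variant would drop the inter-module cross-terms that list(model.parameters()) keeps.
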
@@ -250,11 +279,11 @@ def test_compute_gramian_various_output_shapes( |
     inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)

-    losses = forward_pass(model_autograd, inputs, loss_fn, reduction)
+    losses, params = _get_losses_and_params(model_autograd, inputs, loss_fn, reduction)
     reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination)
     # Go back to a vector so that compute_gramian_with_autograd works
     loss_vector = reshaped_losses.reshape([-1])
-    autograd_gramian = compute_gramian_with_autograd(loss_vector, list(model_autograd.parameters()))
+    autograd_gramian = compute_gramian_with_autograd(loss_vector, params)
     expected_gramian = reshape_gramian(autograd_gramian, list(reshaped_losses.shape))

     engine = Engine(model_autogram, batch_dim=batch_dim)
@@ -289,6 +318,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int |
     for m in gramian_modules:
         gramian_params += list(m.parameters())

+    # This includes cross-terms, but the model has no parameter reuse.
     losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
     autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True)

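
Background on the cross-terms terminology used in these comments: the Gramian under test is G = J @ J.T, where row i of J is the gradient of loss i flattened across all listed parameters, so G[i, j] = <grad L_i, grad L_j>. When a parameter is reused, grad L_i is the sum of the contributions from every use site, and G therefore picks up inner products between different sites: the cross-terms. A hedged sketch of how a compute_gramian_with_autograd-style helper could work under that definition (the repository's actual helper may differ):

# Hedged sketch, not the repository's implementation: one backward pass per
# loss builds the Jacobian J, and the Gramian is J @ J.T.
import torch
from torch import Tensor
from torch.nn import Parameter


def gramian_with_autograd(
    losses: Tensor, params: list[Parameter], retain_graph: bool = False
) -> Tensor:
    rows = []
    for i, loss in enumerate(losses):
        # Retain the graph between rows; honor the caller's flag on the last.
        grads = torch.autograd.grad(
            loss,
            params,
            retain_graph=True if i + 1 < len(losses) else retain_graph,
            allow_unused=True,  # parameters unused by this loss get zeros
        )
        rows.append(
            torch.cat(
                [
                    (g if g is not None else torch.zeros_like(p)).reshape(-1)
                    for g, p in zip(grads, params)
                ]
            )
        )
    jacobian = torch.stack(rows)  # (num_losses, num_params)
    return jacobian @ jacobian.T

On a model with a parameter shared across modules, calling such a helper with params = list(model.parameters()) versus the per-module clones from a CloneParams-style manager is exactly the with/without cross-terms distinction exercised by the new helpers above.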
|