
Commit 7a95b96

test(autogram): Add a way to test against no cross-terms (#467)
* Add CloneParams context to consider each parameter usage on a per-module-usage basis.
* Add _get_losses_and_params_with_cross_terms, _get_losses_and_params_without_cross_terms, and _get_losses_and_params to select between both.
1 parent 6ea78c0 commit 7a95b96
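
For context, the "cross-terms" in question are the Gramian contributions that couple two different usages of the same (shared) parameter: when a weight is used twice, the gradient of each loss with respect to that weight is the sum of its two per-usage gradients, so the resulting Gramian contains usage-A/usage-B cross products, and cloning the parameter per usage drops them. Below is a minimal, self-contained sketch of that distinction; the toy model and names are made up and not part of the commit.

# Toy illustration of "cross-terms" in a loss Gramian, using plain PyTorch only.
import torch
from torch import nn
from torch.autograd import grad
from torch.nn.functional import linear

torch.manual_seed(0)
shared = nn.Linear(3, 3, bias=False)  # this weight is used twice below
x = torch.randn(4, 3)

# With cross-terms: differentiate w.r.t. the single shared weight.
losses = (shared(shared(x)) ** 2).mean(dim=1)  # one loss per sample
rows = [grad(loss, shared.weight, retain_graph=True)[0].flatten() for loss in losses]
jac = torch.stack(rows)
gramian_with = jac @ jac.T  # includes usage-A/usage-B cross products

# Without cross-terms: one clone of the weight per usage, which is what CloneParams
# automates per module call.
w1 = shared.weight.detach().clone().requires_grad_()
w2 = shared.weight.detach().clone().requires_grad_()
losses = (linear(linear(x, w1), w2) ** 2).mean(dim=1)
rows = [torch.cat([g.flatten() for g in grad(loss, [w1, w2], retain_graph=True)]) for loss in losses]
jac = torch.stack(rows)
gramian_without = jac @ jac.T  # only usage-A/usage-A and usage-B/usage-B blocks

print(gramian_with - gramian_without)  # generally non-zero: exactly the cross-terms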

File tree: 2 files changed (+108 −5 lines)

tests/unit/autogram/test_engine.py

Lines changed: 35 additions & 5 deletions
@@ -6,9 +6,10 @@
 import torch
 from pytest import mark, param
 from torch import Tensor
-from torch.nn import BatchNorm2d, InstanceNorm2d, Linear
+from torch.nn import BatchNorm2d, InstanceNorm2d, Linear, Module, Parameter
 from torch.optim import SGD
 from torch.testing import assert_close
+from torch.utils._pytree import PyTree
 from utils.architectures import (
     AlexNet,
     Cifar10Model,
@@ -64,6 +65,7 @@
 )
 from utils.dict_assertions import assert_tensor_dicts_are_close
 from utils.forward_backwards import (
+    CloneParams,
     autograd_forward_backward,
     autogram_forward_backward,
     compute_gramian,
@@ -148,15 +150,42 @@ def _assert_gramian_is_equivalent_to_autograd(
     inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)

-    losses = forward_pass(model_autograd, inputs, loss_fn, reduce_to_vector)
-    autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters()))
+    losses, params = _get_losses_and_params(model_autograd, inputs, loss_fn, reduce_to_vector)
+    autograd_gramian = compute_gramian_with_autograd(losses, params)

     losses = forward_pass(model_autogram, inputs, loss_fn, reduce_to_vector)
     autogram_gramian = engine.compute_gramian(losses)

     assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=3e-5)


+def _get_losses_and_params_with_cross_terms(
+    model: Module,
+    inputs: PyTree,
+    loss_fn: Callable[[PyTree], list[Tensor]],
+    reduction: Callable[[list[Tensor]], Tensor],
+) -> tuple[Tensor, list[Parameter]]:
+    losses = forward_pass(model, inputs, loss_fn, reduction)
+    params = list(model.parameters())
+    return losses, params
+
+
+def _get_losses_and_params_without_cross_terms(
+    model: Module,
+    inputs: PyTree,
+    loss_fn: Callable[[PyTree], list[Tensor]],
+    reduction: Callable[[list[Tensor]], Tensor],
+) -> tuple[Tensor, list[Parameter]]:
+    # Not considering cross-terms (except intra-module parameter reuse):
+    with CloneParams(model) as params:
+        losses = forward_pass(model, inputs, loss_fn, reduction)
+
+    return losses, params
+
+
+_get_losses_and_params = _get_losses_and_params_with_cross_terms
+
+
 @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS)
 @mark.parametrize("batch_dim", [0, None])
 def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int | None):
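
The `_get_losses_and_params` alias at the end of this hunk is what the equivalence tests call; as committed it keeps the historical behaviour (cross-terms included). A hedged sketch of the toggle (this reassignment is not part of the commit) that would make the same tests compare against autograd Gramians computed on per-usage parameter clones:

# Hypothetical toggle: run the equivalence tests without cross-terms between
# different usages of a shared parameter.
_get_losses_and_params = _get_losses_and_params_without_cross_terms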
@@ -250,11 +279,11 @@ def test_compute_gramian_various_output_shapes(
     inputs, targets = make_inputs_and_targets(model_autograd, batch_size)
     loss_fn = make_mse_loss_fn(targets)

-    losses = forward_pass(model_autograd, inputs, loss_fn, reduction)
+    losses, params = _get_losses_and_params(model_autograd, inputs, loss_fn, reduction)
     reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination)
     # Go back to a vector so that compute_gramian_with_autograd works
     loss_vector = reshaped_losses.reshape([-1])
-    autograd_gramian = compute_gramian_with_autograd(loss_vector, list(model_autograd.parameters()))
+    autograd_gramian = compute_gramian_with_autograd(loss_vector, params)
     expected_gramian = reshape_gramian(autograd_gramian, list(reshaped_losses.shape))

     engine = Engine(model_autogram, batch_dim=batch_dim)
@@ -289,6 +318,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int
     for m in gramian_modules:
         gramian_params += list(m.parameters())

+    # This includes cross-terms, but the model has no parameter reuse.
     losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
     autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True)

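Why the comment added above is sufficient here: without parameter reuse each parameter has exactly one usage, so its per-usage gradient equals its ordinary gradient and the Gramians with and without cross-terms coincide; the partial-Gramian test therefore does not need the CloneParams path.
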
tests/utils/forward_backwards.py

Lines changed: 73 additions & 0 deletions
@@ -4,6 +4,7 @@
 from torch import Tensor, nn, vmap
 from torch.nn.functional import mse_loss
 from torch.utils._pytree import PyTree, tree_flatten, tree_map
+from torch.utils.hooks import RemovableHandle
 from utils.architectures import get_in_out_shapes
 from utils.contexts import fork_rng

@@ -144,3 +145,75 @@ def compute_gramian(matrix: Tensor) -> Tensor:
     indices = list(range(matrix.ndim))
     transposed_matrix = matrix.movedim(indices, indices[::-1])
     return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0]))
+
+
+class CloneParams:
+    """
+    ContextManager enabling the computation of per-usage gradients.
+
+    For each submodule with direct trainable parameters, registers:
+    - A pre-hook that clones the params before using them, so that gradients will be computed with
+      respect to the cloned params.
+    - A post-hook that restores the original params.
+
+    The list of clones is returned so that we know where to find the .grad values corresponding to
+    each individual usage of a parameter.
+
+    Exiting this context manager takes care of removing hooks and restoring the original params (in
+    case an exception occurred before the post-hook could do it).
+
+    Note that this does not work for intra-module parameter reuse, which would require a node-based
+    algorithm rather than a module-based algorithm.
+    """
+
+    def __init__(self, model: nn.Module):
+        self.model = model
+        self.clones = list[nn.Parameter]()
+        self._module_to_original_params = dict[nn.Module, dict[str, nn.Parameter]]()
+        self._handles: list[RemovableHandle] = []
+
+    def __enter__(self) -> list[nn.Parameter]:
+        """Register hooks and return the list of cloned parameters."""
+
+        def pre_hook(module: nn.Module, _) -> None:
+            self._module_to_original_params[module] = {}
+            for name, param in module.named_parameters():
+                if param is None or not param.requires_grad:
+                    continue
+                self._module_to_original_params[module][name] = param
+                clone = nn.Parameter(param.detach().clone().requires_grad_())
+                self._set_module_param(module, name, clone)
+                self.clones.append(clone)
+
+        def post_hook(module: nn.Module, _, __) -> None:
+            self._restore_original_params(module)
+
+        # Register hooks on all modules with direct trainable params
+        for mod in self.model.modules():
+            if any(p.requires_grad for p in mod.parameters(recurse=False)):
+                self._handles.append(mod.register_forward_pre_hook(pre_hook))
+                self._handles.append(mod.register_forward_hook(post_hook))
+
+        return self.clones
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Remove hooks and restore parameters."""
+        for handle in self._handles:
+            handle.remove()
+        for module in self.model.modules():
+            self._restore_original_params(module)
+
+        return False  # don't suppress exceptions
+
+    def _restore_original_params(self, module: nn.Module):
+        original_params = self._module_to_original_params.pop(module, {})
+        for name, param in original_params.items():
+            self._set_module_param(module, name, param)
+
+    @staticmethod
+    def _set_module_param(module: nn.Module, name: str, param: nn.Parameter) -> None:
+        name_parts = name.split(".")
+        for module_name in name_parts[:-1]:
+            module = module.get_submodule(module_name)
+        param_name = name_parts[-1]
+        setattr(module, param_name, param)
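
A minimal usage sketch of the class added above. It assumes CloneParams is importable from utils.forward_backwards, as in this test suite; the TwiceApplied toy model and all names below are made up for illustration, not part of the repo.

# Hypothetical sketch: per-usage gradients for a submodule that is called twice.
import torch
from torch import nn

from utils.forward_backwards import CloneParams  # import path as used by the tests


class TwiceApplied(nn.Module):
    """Toy model (not from the repo) that calls the same Linear twice."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 3, bias=False)

    def forward(self, x):
        return self.layer(self.layer(x))


model = TwiceApplied()
x = torch.randn(4, 3)

with CloneParams(model) as clones:
    loss = model(x).square().mean()

loss.backward()

# The pre-hook fired once per call of `layer`, so there is one clone per usage,
# each holding the gradient of that individual usage.
assert len(clones) == 2
per_usage_grads = [clone.grad for clone in clones]

# The ordinary gradient on the shared weight is the sum of the per-usage gradients.
model.zero_grad()
model(x).square().mean().backward()
torch.testing.assert_close(model.layer.weight.grad, per_usage_grads[0] + per_usage_grads[1])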
