
Commit 30fdc00

Add compute_gramian_with_autograd_no_cross_terms and CloneParams
1 parent d3db44f commit 30fdc00
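
What the new helper computes, briefly (a hedged reading of the diff below, not part of the original commit message): CloneParams gives every usage of a trainable parameter its own clone, get_vjp takes the Jacobian J_u of the loss vector with respect to each cloned usage u, and the returned matrix is

    gramian = sum over usages u of J_u @ J_u.T

Cross terms J_u @ J_v.T between two distinct usages u != v of a shared parameter, which the Gramian of the total (summed) Jacobian would contain, are left out; this appears to be what the _no_cross_terms suffix refers to.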

1 file changed: +101 -0 lines changed


tests/utils/forward_backwards.py

Lines changed: 101 additions & 0 deletions
@@ -4,6 +4,7 @@
 from torch import Tensor, nn, vmap
 from torch.nn.functional import mse_loss
 from torch.utils._pytree import PyTree, tree_flatten, tree_map
+from torch.utils.hooks import RemovableHandle
 from utils.architectures import get_in_out_shapes
 from utils.contexts import fork_rng

@@ -138,9 +139,109 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]:
     return gramian


+def compute_gramian_with_autograd_no_cross_terms(
+    model: nn.Module,
+    inputs: PyTree,
+    loss_fn: Callable[[PyTree], list[Tensor]],
+):
+    with CloneParams(model) as usage_clones:
+        output = model(inputs)
+
+    _, expected_output_shapes = get_in_out_shapes(model)
+    assert tree_map(lambda t: t.shape[1:], output) == expected_output_shapes
+
+    loss_tensors = loss_fn(output)
+    loss_vector = reduce_to_vector(loss_tensors)
+
+    def get_vjp(grad_outputs: Tensor) -> list[Tensor]:
+        grads = torch.autograd.grad(
+            loss_vector,
+            [cloned_param for orig_param_id, cloned_param in usage_clones],
+            grad_outputs=grad_outputs,
+            retain_graph=False,
+            allow_unused=True,
+        )
+        return [grad for grad in grads if grad is not None]
+
+    jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(loss_vector)))
+    jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians]
+
+    gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices])
+
+    return gramian
+
+
 def compute_gramian(matrix: Tensor) -> Tensor:
     """Contracts the last dimension of matrix to make it into a Gramian."""

     indices = list(range(matrix.ndim))
     transposed_matrix = matrix.movedim(indices, indices[::-1])
     return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0]))
+
+
+class CloneParams:
+    """
+    ContextManager enabling the computation of per-usage gradients.
+
+    For each submodule with direct trainable parameters, registers:
+    - A pre-hook that clones the params before using them, so that gradients will be computed with
+      respect to the cloned params.
+    - A post-hook that restores the original params.
+
+    The list of clones is returned so that we know where to find the .grad values corresponding to
+    each individual usage of a parameter.
+
+    Exiting this context manager takes care of removing hooks and restoring the original params (in
+    case an exception occurred before the post-hook could do it).
+
+    Note that this does not work for intra-module parameter reuse, which would require a node-based
+    algorithm rather than a module-based algorithm.
+    """
+
+    def __init__(self, model: nn.Module):
+        self.model = model
+        self.usage_clones: list[tuple[int, nn.Parameter]] = []
+        self._orig_params_storage: dict[int, dict[str, nn.Parameter]] = {}
+        self._handles: list[RemovableHandle] = []
+
+    def __enter__(self) -> list[tuple[int, nn.Parameter]]:
+        """Register hooks and return list of (orig_param_id, clone_param)."""
+
+        def pre_hook(module: nn.Module, _) -> None:
+            saved: dict[str, nn.Parameter] = {}
+            for name, orig_param in module.named_parameters():
+                if orig_param is None or not orig_param.requires_grad:
+                    continue
+                clone_tensor = orig_param.detach().clone().requires_grad_()
+                clone_param = nn.Parameter(clone_tensor)
+                saved[name] = orig_param
+                setattr(module, name, clone_param)
+                self.usage_clones.append((id(orig_param), clone_param))
+            self._orig_params_storage[id(module)] = saved
+
+        def post_hook(module: nn.Module, _, __) -> None:
+            self._restore_original_params(module)
+
+        # Register hooks on all modules with direct trainable params
+        for mod in self.model.modules():
+            if any(p.requires_grad for p in mod.parameters(recurse=False)):
+                self._handles.append(mod.register_forward_pre_hook(pre_hook))
+                self._handles.append(mod.register_forward_hook(post_hook))
+
+        return self.usage_clones
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Remove hooks and restore parameters."""
+        for handle in self._handles:
+            handle.remove()
+        for module in self.model.modules():
+            self._restore_original_params(module)
+
+        return False  # don’t suppress exceptions
+
+    def _restore_original_params(self, module: nn.Module):
+        saved = self._orig_params_storage.get(id(module), {})
+        for name, orig_param in saved.items():
+            setattr(module, name, orig_param)
+        if id(module) in self._orig_params_storage:
+            del self._orig_params_storage[id(module)]
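
To make "per-usage" concrete, here is a minimal sketch of CloneParams in isolation. The toy setup below (a Linear layer called twice inside an nn.Sequential) is an illustrative assumption, not code from this repository; it relies only on the class added above and on imports already present in the file.

# Hypothetical toy setup: the same Linear layer is used twice in one forward pass.
shared = nn.Linear(3, 3)
toy_model = nn.Sequential(shared, nn.ReLU(), shared)
x = torch.randn(2, 3)

with CloneParams(toy_model) as usage_clones:
    # The pre-hook fires once per usage of `shared`, so its weight and bias are
    # each cloned twice: 2 usages x 2 params = 4 (orig_param_id, clone) pairs.
    loss = toy_model(x).sum()

clones = [clone for _, clone in usage_clones]
grads = torch.autograd.grad(loss, clones, allow_unused=True)
# Each entry of `grads` is the gradient of `loss` through one specific usage of a
# parameter, rather than the usual gradient summed over all usages. The stored
# orig_param_id makes it possible to group usages of the same underlying parameter.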

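Finally, a sketch of how the new top-level helper might be invoked. This part is assumption-heavy: it presumes that `model` and `inputs` come from the test-architecture registry (so that get_in_out_shapes(model) is defined and the output-shape assertion passes), that the model returns a single Tensor, and that reduce_to_vector, defined earlier in this file, accepts the list of loss tensors returned by the loss function.

def two_term_loss(output):
    # Made-up loss decomposition: one MSE term against a zero target, one mean-absolute term.
    return [mse_loss(output, torch.zeros_like(output)), output.abs().mean()]

gramian = compute_gramian_with_autograd_no_cross_terms(model, inputs, two_term_loss)
# `gramian` is square, with one row/column per entry of the reduced loss vector;
# cross terms between different usages of any shared parameter are excluded by construction.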