 from torch import Tensor, nn, vmap
 from torch.nn.functional import mse_loss
 from torch.utils._pytree import PyTree, tree_flatten, tree_map
+from torch.utils.hooks import RemovableHandle
 from utils.architectures import get_in_out_shapes
 from utils.contexts import fork_rng
 
@@ -138,9 +139,109 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]:
     return gramian
 
 
+def compute_gramian_with_autograd_no_cross_terms(
+    model: nn.Module,
+    inputs: PyTree,
+    loss_fn: Callable[[PyTree], list[Tensor]],
+):
+    with CloneParams(model) as usage_clones:
+        output = model(inputs)
+
+    _, expected_output_shapes = get_in_out_shapes(model)
+    assert tree_map(lambda t: t.shape[1:], output) == expected_output_shapes
+
+    loss_tensors = loss_fn(output)
+    loss_vector = reduce_to_vector(loss_tensors)
+
+    def get_vjp(grad_outputs: Tensor) -> list[Tensor]:
+        grads = torch.autograd.grad(
+            loss_vector,
+            [cloned_param for orig_param_id, cloned_param in usage_clones],
+            grad_outputs=grad_outputs,
+            retain_graph=False,
+            allow_unused=True,
+        )
+        return [grad for grad in grads if grad is not None]
+
+    jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(loss_vector)))
+    jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians]
+
+    gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices])
+
+    return gramian
+
+
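A note on the name: "no cross terms" refers to parameter reuse. When the same parameter is used in several places, its true gradient is the sum of its per-usage gradients, so the full Gramian of the loss vector contains cross terms between usages; the function above keeps only the per-usage blocks. The sketch below illustrates that difference. It is not part of the commit; the tensors and the helper are made up, and only `torch` is assumed.

```python
import torch

x1, x2 = torch.randn(3), torch.randn(3)

# One logical parameter, cloned once per usage (CloneParams automates this at the
# module level during the forward pass).
w = torch.randn(3)
w_a = w.clone().requires_grad_()  # first usage
w_b = w.clone().requires_grad_()  # second usage

losses = [(w_a @ x1) * (w_b @ x2), (w_a @ x1) ** 2 + (w_b @ x2) ** 2]

def jacobian_wrt(param):
    # One row per loss: d(loss_i) / d(param).
    rows = [torch.autograd.grad(loss, param, retain_graph=True)[0] for loss in losses]
    return torch.stack(rows)

J_a, J_b = jacobian_wrt(w_a), jacobian_wrt(w_b)

# The true gradient of each loss w.r.t. w sums over usages, so the full Gramian
# contains cross terms between the two usages:
full_gramian = (J_a + J_b) @ (J_a + J_b).T
# The "no cross terms" variant keeps only the per-usage blocks:
gramian_no_cross_terms = J_a @ J_a.T + J_b @ J_b.T
```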
 def compute_gramian(matrix: Tensor) -> Tensor:
     """Contracts the last dimension of matrix to make it into a Gramian."""
 
     indices = list(range(matrix.ndim))
     transposed_matrix = matrix.movedim(indices, indices[::-1])
     return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0]))
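As a quick sanity check on the docstring above (illustration only, not part of the commit; it assumes `compute_gramian` is importable from this file): for a 2-D input the function reduces to `J @ J.T`, and for higher-dimensional inputs it contracts the last axis and mirrors the remaining ones.

```python
import torch

J = torch.randn(5, 7)
assert torch.allclose(compute_gramian(J), J @ J.T)

T = torch.randn(2, 3, 7)
G = compute_gramian(T)  # shape: (2, 3, 3, 2)
assert torch.allclose(G, torch.einsum("abk,ABk->abBA", T, T))
```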
+
+
+class CloneParams:
+    """
+    ContextManager enabling the computation of per-usage gradients.
+
+    For each submodule with direct trainable parameters, registers:
+    - A pre-hook that clones the params before using them, so that gradients will be computed with
+      respect to the cloned params.
+    - A post-hook that restores the original params.
+
+    The list of clones is returned so that we know where to find the .grad values corresponding to
+    each individual usage of a parameter.
+
+    Exiting this context manager takes care of removing hooks and restoring the original params (in
+    case an exception occurred before the post-hook could do it).
+
+    Note that this does not work for intra-module parameter reuse, which would require a node-based
+    algorithm rather than a module-based algorithm.
+    """
+
+    def __init__(self, model: nn.Module):
+        self.model = model
+        self.usage_clones: list[tuple[int, nn.Parameter]] = []
+        self._orig_params_storage: dict[int, dict[str, nn.Parameter]] = {}
+        self._handles: list[RemovableHandle] = []
+
+    def __enter__(self) -> list[tuple[int, nn.Parameter]]:
+        """Register hooks and return list of (orig_param_id, clone_param)."""
+
+        def pre_hook(module: nn.Module, _) -> None:
+            saved: dict[str, nn.Parameter] = {}
+            for name, orig_param in module.named_parameters(recurse=False):  # direct params only
+                if orig_param is None or not orig_param.requires_grad:
+                    continue
+                clone_tensor = orig_param.detach().clone().requires_grad_()
+                clone_param = nn.Parameter(clone_tensor)
+                saved[name] = orig_param
+                setattr(module, name, clone_param)
+                self.usage_clones.append((id(orig_param), clone_param))
+            self._orig_params_storage[id(module)] = saved
+
+        def post_hook(module: nn.Module, _, __) -> None:
+            self._restore_original_params(module)
+
+        # Register hooks on all modules with direct trainable params
+        for mod in self.model.modules():
+            if any(p.requires_grad for p in mod.parameters(recurse=False)):
+                self._handles.append(mod.register_forward_pre_hook(pre_hook))
+                self._handles.append(mod.register_forward_hook(post_hook))
+
+        return self.usage_clones
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Remove hooks and restore parameters."""
+        for handle in self._handles:
+            handle.remove()
+        for module in self.model.modules():
+            self._restore_original_params(module)
+
+        return False  # don't suppress exceptions
+
+    def _restore_original_params(self, module: nn.Module):
+        saved = self._orig_params_storage.get(id(module), {})
+        for name, orig_param in saved.items():
+            setattr(module, name, orig_param)
+        if id(module) in self._orig_params_storage:
+            del self._orig_params_storage[id(module)]
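To see what the returned `usage_clones` buys you, here is a hedged usage sketch (illustration only, not part of the commit). The toy model and the assumption that `CloneParams` is importable from this file are mine; the point is that a module used twice in one forward pass yields two clones of the same weight, and hence two per-usage gradients whose sum is the ordinary gradient.

```python
import torch
from torch import nn

# Hypothetical toy model: the same Linear module is used twice in the forward pass.
shared = nn.Linear(4, 4, bias=False)
model = nn.Sequential(shared, nn.ReLU(), shared)

x = torch.randn(2, 4)
with CloneParams(model) as usage_clones:
    loss = model(x).sum()

# Two entries for shared.weight: one clone per forward usage.
clones = [clone for _, clone in usage_clones]
per_usage_grads = torch.autograd.grad(loss, clones)

# Summing over usages recovers the gradient that loss.backward() would have
# accumulated into shared.weight.grad on the original (uncloned) model.
total_grad = per_usage_grads[0] + per_usage_grads[1]
```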