|
| 1 | +from abc import ABC, abstractmethod |
| 2 | +from collections.abc import Callable |
| 3 | +from typing import cast |
| 4 | + |
| 5 | +import torch |
| 6 | +from torch import Tensor, nn |
| 7 | +from torch.nn import Parameter |
| 8 | +from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only |
| 9 | + |
| 10 | +# Note about import from protected _pytree module: |
| 11 | +# PyTorch maintainers plan to make pytree public (see |
| 12 | +# https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400). |
| 13 | +# It should also come with better speed, because the current implementation is slow, according to |
| 14 | +# https://github.com/pytorch/pytorch/issues/65761#issue-1010116111. |
| 15 | +# When pytree becomes public, this import will have to be changed with a conditional import (to |
| 16 | +# still support older versions of PyTorch where pytree is protected). |
| 17 | + |
| 18 | + |
| 19 | +class JacobianComputer(ABC): |
| 20 | + """ |
| 21 | + Abstract class to computes Jacobians for a module's forward pass with respect to its parameters. |
| 22 | +
|
| 23 | + :params module: The module to differentiate. |
| 24 | + """ |
| 25 | + |
| 26 | + def __init__(self, module: nn.Module): |
| 27 | + self.module = module |
| 28 | + |
| 29 | + self.rg_params = dict[str, Parameter]() |
| 30 | + self.frozen_params = dict[str, Parameter]() |
| 31 | + |
| 32 | + for name, param in module.named_parameters(recurse=True): |
| 33 | + if param.requires_grad: |
| 34 | + self.rg_params[name] = param |
| 35 | + else: |
| 36 | + self.frozen_params[name] = param |
| 37 | + |
| 38 | + def __call__( |
| 39 | + self, |
| 40 | + rg_outputs: tuple[Tensor, ...], |
| 41 | + grad_outputs: tuple[Tensor, ...], |
| 42 | + args: tuple[PyTree, ...], |
| 43 | + kwargs: dict[str, PyTree], |
| 44 | + ) -> Tensor: |
| 45 | + # This makes __call__ vmappable. |
| 46 | + return ComputeModuleJacobians.apply( |
| 47 | + self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs |
| 48 | + ) |
| 49 | + |
| 50 | + @abstractmethod |
| 51 | + def _compute_jacobian( |
| 52 | + self, |
| 53 | + rg_outputs: tuple[Tensor, ...], |
| 54 | + grad_outputs: tuple[Tensor, ...], |
| 55 | + args: tuple[PyTree, ...], |
| 56 | + kwargs: dict[str, PyTree], |
| 57 | + ) -> Tensor: |
| 58 | + """ |
| 59 | + Computes and returns the Jacobian. The output must be a matrix (2D Tensor). |
| 60 | + """ |
| 61 | + |
| 62 | + |
| 63 | +class FunctionalJacobianComputer(JacobianComputer): |
| 64 | + """ |
| 65 | + JacobianComputer using the functional differentiation API. This requires to use vmap, so it's |
| 66 | + not compatible with every module, and it requires to have an extra forward pass to create the |
| 67 | + vjp function. |
| 68 | + """ |
| 69 | + |
| 70 | + def _compute_jacobian( |
| 71 | + self, |
| 72 | + _: tuple[Tensor, ...], |
| 73 | + grad_outputs: tuple[Tensor, ...], |
| 74 | + args: tuple[PyTree, ...], |
| 75 | + kwargs: dict[str, PyTree], |
| 76 | + ) -> Tensor: |
| 77 | + grad_outputs_in_dims = (0,) * len(grad_outputs) |
| 78 | + args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) |
| 79 | + kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) |
| 80 | + in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims) |
| 81 | + vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) |
| 82 | + |
| 83 | + return vmapped_vjp(grad_outputs, args, kwargs) |
| 84 | + |
| 85 | + def _call_on_one_instance( |
| 86 | + self, |
| 87 | + grad_outputs_j: tuple[Tensor, ...], |
| 88 | + args_j: tuple[PyTree, ...], |
| 89 | + kwargs_j: dict[str, PyTree], |
| 90 | + ) -> Tensor: |
| 91 | + # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a |
| 92 | + # "batch" of 1 activation (or grad_output). This is because some layers (e.g. |
| 93 | + # nn.Flatten) do not work equivalently if they're provided with a batch or with |
| 94 | + # an element of a batch. We thus always provide them with batches, just of a |
| 95 | + # different size. |
| 96 | + args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) |
| 97 | + kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j) |
| 98 | + grad_outputs_j_ = tuple(x.unsqueeze(0) for x in grad_outputs_j) |
| 99 | + |
| 100 | + def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]: |
| 101 | + all_state = [ |
| 102 | + cast(dict[str, Tensor], rg_params), |
| 103 | + dict(self.module.named_buffers()), |
| 104 | + cast(dict[str, Tensor], self.frozen_params), |
| 105 | + ] |
| 106 | + output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) |
| 107 | + flat_outputs = tree_flatten(output)[0] |
| 108 | + rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad) |
| 109 | + return rg_outputs |
| 110 | + |
| 111 | + vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1] |
| 112 | + |
| 113 | + # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the |
| 114 | + # functional has a single primal which is dict(module.named_parameters()). We therefore take |
| 115 | + # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters. |
| 116 | + gradients = vjp_func(grad_outputs_j_)[0] |
| 117 | + gradient = torch.cat([t.reshape(-1) for t in gradients.values()]) |
| 118 | + return gradient |
| 119 | + |
| 120 | + |
| 121 | +class AutogradJacobianComputer(JacobianComputer): |
| 122 | + """ |
| 123 | + JacobianComputer using the autograd engine. The main advantage of using this method is that it |
| 124 | + doesn't require making an extra forward pass. |
| 125 | + """ |
| 126 | + |
| 127 | + def _compute_jacobian( |
| 128 | + self, |
| 129 | + rg_outputs: tuple[Tensor, ...], |
| 130 | + grad_outputs: tuple[Tensor, ...], |
| 131 | + _: tuple[PyTree, ...], |
| 132 | + __: dict[str, PyTree], |
| 133 | + ) -> Tensor: |
| 134 | + flat_rg_params, ___ = tree_flatten(self.rg_params) |
| 135 | + grads = torch.autograd.grad( |
| 136 | + rg_outputs, |
| 137 | + flat_rg_params, |
| 138 | + grad_outputs, |
| 139 | + retain_graph=True, |
| 140 | + allow_unused=True, |
| 141 | + materialize_grads=True, |
| 142 | + ) |
| 143 | + flattened_grads = torch.cat([g.reshape(-1) for g in grads]) |
| 144 | + jacobian = flattened_grads.unsqueeze(0) |
| 145 | + return jacobian |
| 146 | + |
| 147 | + |
| 148 | +class ComputeModuleJacobians(torch.autograd.Function): |
| 149 | + @staticmethod |
| 150 | + def forward( |
| 151 | + compute_jacobian_fn: Callable[ |
| 152 | + [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Tensor |
| 153 | + ], |
| 154 | + rg_outputs: tuple[Tensor, ...], |
| 155 | + grad_outputs: tuple[Tensor, ...], |
| 156 | + args: tuple[PyTree, ...], |
| 157 | + kwargs: dict[str, PyTree], |
| 158 | + ) -> Tensor: |
| 159 | + # There is no non-batched dimension |
| 160 | + jacobian = compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs) |
| 161 | + return jacobian |
| 162 | + |
| 163 | + @staticmethod |
| 164 | + def vmap( |
| 165 | + _, |
| 166 | + in_dims: tuple[None, None, tuple[int, ...], None, None], |
| 167 | + compute_jacobian_fn: Callable, |
| 168 | + rg_outputs: tuple[Tensor, ...], |
| 169 | + jac_outputs: tuple[Tensor, ...], |
| 170 | + args: tuple[PyTree, ...], |
| 171 | + kwargs: dict[str, PyTree], |
| 172 | + ) -> tuple[Tensor, None]: |
| 173 | + # There is a non-batched dimension |
| 174 | + # We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension |
| 175 | + generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])( |
| 176 | + rg_outputs, |
| 177 | + jac_outputs, |
| 178 | + args, |
| 179 | + kwargs, |
| 180 | + ) |
| 181 | + shape = generalized_jacobian.shape |
| 182 | + jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1]) |
| 183 | + return jacobian, None |
| 184 | + |
| 185 | + @staticmethod |
| 186 | + def setup_context(*_) -> None: |
| 187 | + pass |
0 commit comments