|
1 | 1 | from abc import ABC, abstractmethod |
| 2 | +from collections.abc import Callable |
2 | 3 | from typing import Optional |
3 | 4 |
|
4 | | -from torch import Tensor |
| 5 | +import torch |
| 6 | +from torch import Tensor, nn |
5 | 7 | from torch.utils._pytree import PyTree |
6 | 8 |
|
| 9 | +from torchjd.autogram._gramian_utils import reshape_gramian |
7 | 10 | from torchjd.autogram._jacobian_computer import JacobianComputer |
8 | 11 |
|
9 | 12 |
|
@@ -76,3 +79,72 @@ def __call__( |
76 | 79 | return gramian |
77 | 80 | else: |
78 | 81 | return None |
| 82 | + |
| 83 | + |
class LinearBasedGramianComputer(GramianComputer):
    """Gramian computer specialized for ``nn.Linear`` modules.

    Rather than materializing the Jacobian of the output w.r.t. the layer
    parameters, it uses the closed-form expression of the Gramian for a linear
    layer, which only involves the layer input and the output gradients.
    """

    def __init__(self, module: nn.Linear):
        # Kept so that the bias configuration of the layer can be inspected.
        self.module = module

    def __call__(
        self,
        _: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        __: dict[str, PyTree],
    ) -> Optional[Tensor]:
        X = args[0]  # input to the linear layer
        dY = grad_outputs[0]  # gradient of the output(s) of the layer

        gramian = ComputeLinearGramian.apply(self._compute_gramian, dY, X)
        return gramian

    def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
        """
        X is a matrix of shape [k, n] and dY1, dY2 are matrices of shape [k, m].
        Returns dY1 @ G @ dY2.T where G is the Gramian of the Jacobian of the
        module output w.r.t. the module params.
        """

        # Weight contribution: equivalent to (dY1 @ dY2.T) * (X @ X.T),
        # computed in a single einsum.
        G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)

        # The bias term must only be included when the layer actually has a
        # bias parameter (nn.Linear may be built with bias=False).
        if self.module.bias is None:
            return G_W

        G_b = torch.einsum("ik,jk->ij", dY1, dY2)
        return G_b + G_W
| 113 | + |
| 114 | + |
class ComputeLinearGramian(torch.autograd.Function):
    """Autograd Function whose custom vmap rule computes a generalized Gramian
    over the batched dimension of ``dY`` and flattens it back into an ordinary
    square Gramian matrix."""

    @staticmethod
    def forward(
        compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
        dY: Tensor,
        X: Tensor,
    ) -> Tensor:
        # No non-batched dimension is present: a single direct call suffices.
        return compute_gramian_fn(dY, dY, X)

    @staticmethod
    def vmap(
        _,
        in_dims: tuple[None, tuple[int, ...], None],
        compute_gramian_fn: Callable[[Tensor, Tensor, Tensor], Tensor],
        dY: Tensor,
        X: Tensor,
    ) -> tuple[Tensor, None]:
        # A non-batched dimension is present: pair every slice of dY with every
        # other slice through two nested vmaps, producing a generalized Gramian
        # that is then reshaped into a flat square Gramian.
        dY_batch_dim = in_dims[1]
        pairwise = torch.vmap(
            compute_gramian_fn,
            in_dims=(dY_batch_dim, None, None),
            out_dims=0,
        )
        generalized_gramian = torch.vmap(
            pairwise,
            in_dims=(None, dY_batch_dim, None),
            out_dims=-1,
        )(dY, dY, X)
        flat_size = dY.shape[0] * dY.shape[1]
        return reshape_gramian(generalized_gramian, [flat_size]), None

    @staticmethod
    def setup_context(*_) -> None:
        # Nothing has to be saved for the backward pass.
        pass
0 commit comments