1 | 1 | r""" |
2 | 2 | When doing Jacobian descent, the Jacobian matrix has to be aggregated into a vector to store in the |
3 | 3 | ``.grad`` fields of the model parameters. The |
4 | | -The :class:`~torchjd.aggregation._aggregator_bases.Aggregator` is responsible for these |
5 | | -aggregations. |
| 4 | +:class:`~torchjd.aggregation._aggregator_bases.Aggregator` is responsible for these aggregations. |
6 | 5 |
|
7 | 6 | When using the :doc:`autogram <../autogram/index>` engine, we rather need to extract a vector |
8 | 7 | of weights from the Gramian of the Jacobian. The |
:class:`~torchjd.aggregation._weighting_bases.Weighting` is responsible for these extractions.

...

:class:`Aggregators <torchjd.aggregation._aggregator_bases.Aggregator>` and :class:`Weightings
<torchjd.aggregation._weighting_bases.Weighting>` are callables that take a Jacobian matrix or a
Gramian matrix as input, respectively. The following example shows how to use UPGrad to either
aggregate a Jacobian (of shape ``[m, n]``, where ``m`` is the number of objectives and ``n`` is the
number of parameters), or obtain the weights from the Gramian of the Jacobian (of shape ``[m, m]``).

>>> from torch import tensor
>>> from torchjd.aggregation import UPGrad, UPGradWeighting
>>>
>>> jacobian = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>> aggregator = UPGrad()
>>> aggregator(jacobian)
tensor([0.2929, 1.9004, 1.9004])
>>>
>>> gramian = jacobian @ jacobian.T  # square Gramian of the Jacobian, of shape [2, 2]
>>> weighting = UPGradWeighting()
>>> weights = weighting(gramian)
>>> weights
tensor([1.1109, 0.7894])
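
In this example, aggregating the Jacobian with UPGrad amounts to combining its rows, weighted by
the weights that UPGradWeighting extracted from its Gramian. A quick sanity check of this
relationship:

>>> weights @ jacobian
tensor([0.2929, 1.9004, 1.9004])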

When dealing with a more general tensor of objectives, of shape ``[m_1, ..., m_k]`` (i.e. not
necessarily a simple vector), the Jacobian will be of shape ``[m_1, ..., m_k, n]``, and its Gramian
will be called a `generalized Gramian`, of shape ``[m_1, ..., m_k, m_k, ..., m_1]``. One can use a
:class:`GeneralizedWeighting <torchjd.aggregation._weighting_bases.GeneralizedWeighting>` to
extract a tensor of weights (of shape ``[m_1, ..., m_k]``) from such a generalized Gramian. The
simplest :class:`GeneralizedWeighting <torchjd.aggregation._weighting_bases.GeneralizedWeighting>`
is :class:`Flattening <torchjd.aggregation._flattening.Flattening>`: it simply "flattens" the
generalized Gramian into a square Gramian matrix (of shape ``[m_1 * ... * m_k, m_1 * ... * m_k]``),
applies a normal weighting to it to obtain a vector of weights, and returns the reshaped tensor of
weights.

>>> from torch import ones
>>> from torchjd.aggregation import Flattening, UPGradWeighting
>>>
>>> weighting = Flattening(UPGradWeighting())
>>> # Generate a generalized Gramian filled with ones, for the sake of the example
>>> generalized_gramian = ones((2, 3, 3, 2))
>>> weights = weighting(generalized_gramian)
>>> weights
tensor([[0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667]])
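
In this example, the generalized Gramian of shape ``[2, 3, 3, 2]`` is flattened into a Gramian
matrix of ones, of shape ``[6, 6]``, from which the inner UPGradWeighting extracts uniform weights
(each equal to 1/6), reshaped back to ``[2, 3]``. As a sanity check for this all-ones case,
applying the inner weighting to the flattened Gramian directly gives the same values:

>>> UPGradWeighting()(ones((6, 6))).reshape((2, 3))
tensor([[0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667]])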
38 | 60 | """ |
39 | 61 |
|
40 | 62 | from ._aggregator_bases import Aggregator |
from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting
from ._config import ConFIG
from ._constant import Constant, ConstantWeighting
from ._dualproj import DualProj, DualProjWeighting
from ._flattening import Flattening
from ._graddrop import GradDrop
from ._imtl_g import IMTLG, IMTLGWeighting
from ._krum import Krum, KrumWeighting
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
from ._sum import Sum, SumWeighting
from ._trimmed_mean import TrimmedMean
from ._upgrad import UPGrad, UPGradWeighting
from ._utils.check_dependencies import (
    OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
)
from ._weighting_bases import GeneralizedWeighting, Weighting

try:
    from ._cagrad import CAGrad, CAGradWeighting