2 changes: 1 addition & 1 deletion docs/source/examples/iwmtl.rst
@@ -31,7 +31,7 @@ The following example shows how to do that.
optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module, batch_dim=0)
engine = Engine(shared_module)

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task
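The rest of this example sits outside the hunk. A hedged sketch of one training step with the updated engine, assuming two hypothetical task heads ``task1_module`` and ``task2_module`` on top of ``shared_module`` (the head names, ``task2_targets`` and the exact stacking of the losses are assumptions, not taken from the example file):

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)               # shared representation
    out1 = task1_module(features).squeeze(dim=1)  # shape: [16]
    out2 = task2_module(features).squeeze(dim=1)  # shape: [16]
    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]

    optimizer.zero_grad()
    gramian = engine.compute_gramian(losses)  # generalized Gramian of the loss matrix
    weights = weighting(gramian)              # Flattening(UPGradWeighting()), shape: [16, 2]
    losses.backward(weights)
    optimizer.step()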
2 changes: 1 addition & 1 deletion docs/source/examples/iwrm.rst
@@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
params = model.parameters()
optimizer = SGD(params, lr=0.1)
weighting = UPGradWeighting()
engine = Engine(model, batch_dim=0)
engine = Engine(model)

for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
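The body of this loop is cut off by the hunk. A minimal sketch of how the step could continue, following the pattern shown in the ``Engine`` docstring (``loss_fn`` is an assumed per-instance loss such as ``MSELoss(reduction="none")``, defined outside this hunk):

    # ...continuing inside the `for x, y` loop:
    losses = loss_fn(y_hat, y)                # per-instance losses, shape: [16]
    optimizer.zero_grad()
    gramian = engine.compute_gramian(losses)  # shape: [16, 16]
    weights = weighting(gramian)              # shape: [16]
    losses.backward(weights)
    optimizer.step()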
2 changes: 1 addition & 1 deletion docs/source/examples/partial_jd.rst
@@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.

# Create the autogram engine that will compute the Gramian of the
# Jacobian with respect to the two last Linear layers' parameters.
engine = Engine(model[2:], batch_dim=0)
engine = Engine(model[2:])

params = model.parameters()
optimizer = SGD(params, lr=0.1)
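For context, a hypothetical model matching this description could look as follows (layer sizes are assumptions, not taken from the example file):

from torch import nn

model = nn.Sequential(
    nn.Linear(10, 5),  # model[0]: excluded from the Gramian
    nn.ReLU(),         # model[1]
    nn.Linear(5, 5),   # model[2]: included
    nn.ReLU(),         # model[3]
    nn.Linear(5, 1),   # model[4]: included
)

``model[2:]`` is itself an ``nn.Sequential``, so it can be passed to the engine as a single module. Since the optimizer is built from ``model.parameters()``, the weighted backward pass presumably still updates the first ``Linear`` layer; only the Gramian used for the weighting is restricted to ``model[2:]``.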
210 changes: 41 additions & 169 deletions src/torchjd/autogram/_engine.py
@@ -7,41 +7,9 @@
from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
from ._gramian_utils import movedim_gramian, reshape_gramian
from ._jacobian_computer import (
AutogradJacobianComputer,
FunctionalJacobianComputer,
JacobianComputer,
)
from ._jacobian_computer import AutogradJacobianComputer
from ._module_hook_manager import ModuleHookManager

_MODULES_INCOMPATIBLE_WITH_BATCHED = (
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LazyBatchNorm1d,
nn.LazyBatchNorm2d,
nn.LazyBatchNorm3d,
nn.SyncBatchNorm,
nn.RNNBase,
)

_TRACK_RUNNING_STATS_MODULE_TYPES = (
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LazyBatchNorm1d,
nn.LazyBatchNorm2d,
nn.LazyBatchNorm3d,
nn.SyncBatchNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
nn.LazyInstanceNorm1d,
nn.LazyInstanceNorm2d,
nn.LazyInstanceNorm3d,
)


class Engine:
"""
@@ -50,7 +18,7 @@ class Engine:
Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_ but goes even further:

* It works for any computation graph (not just sequential models).
* It is optimized for batched computations (as long as ``batch_dim`` is specified).
* It is highly optimized for batched computations but also supports non-batched computations.
* It supports any shape of tensor to differentiate (not just a vector of losses). For more
details about this, look at :meth:`Engine.compute_gramian`.

@@ -66,10 +34,6 @@
:param modules: The modules whose parameters will contribute to the Gramian of the Jacobian.
Several modules can be provided, but none of them may be a child module of another.
:param batch_dim: If the modules work with batches and process each batch element independently,
then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial
memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates
the batch dimension of the output tensor, if any.

.. admonition::
Example
@@ -97,7 +61,7 @@ class Engine:
weighting = UPGradWeighting()

# Create the engine before the backward pass, and only once.
engine = Engine(model, batch_dim=0)
engine = Engine(model)

for input, target in zip(inputs, targets):
output = model(input).squeeze(dim=1) # shape: [16]
@@ -113,48 +77,13 @@
since the Jacobian never has to be entirely in memory, it is often much more
memory-efficient, and thus typically faster, to use the Gramian-based approach.
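As a rough numeric illustration of this point (the sizes and the softmax stand-in below are placeholders, not the library's actual weighting):

import torch

J = torch.randn(16, 100_000)        # 16 losses, 100k parameters (illustrative)
G = J @ J.T                         # Gramian, shape: [16, 16], independent of the number of parameters
w = torch.softmax(G.sum(dim=1), 0)  # stand-in for a weighting such as UPGradWeighting
update = J.T @ w                    # in the real pipeline, losses.backward(w) computes this product
                                    # in a single weighted backward pass, without materializing J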

.. warning:: For autogram to be fast and low-memory, it is very important to use only batched
modules (i.e. modules that treat each element of the batch independently). For instance,
BatchNorm is not a batched module because it computes some statistics over the batch.
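For instance, to keep the fast batched path, a BatchNorm layer can typically be swapped for a per-sample normalization such as GroupNorm or InstanceNorm. A minimal sketch (channel counts are assumptions):

from torch import nn

with_batchnorm = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())        # mixes statistics across the batch
batched_alternative = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(4, 8), nn.ReLU())  # normalizes each sample on its own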

.. warning::
When providing a non-None ``batch_dim``, all provided modules must respect a few conditions:

* They should treat the elements of the batch independently. Most common layers respect
this, but for example `BatchNorm
<https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ does not (it
computes a mean and standard deviation over the elements of the batch).
* Their inputs and outputs can be anything, but each input tensor and each output tensor
must be batched on its first dimension. When available (e.g. in `Transformers
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_,
`MultiheadAttention
<https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html>`_,
etc.), the ``batch_first`` parameter has to be set to ``True``. This also means that `RNNs
<https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ are not supported yet,
because their hidden state is batched on dimension 1 even if ``batch_first`` is ``True``.
* They should not perform in-place operations on tensors (for instance you should not use
``track_running_stats=True`` in normalization layers).
* They should not have side effects during the forward pass (since their forward pass will
be called twice, the side effects could be different from what's expected).
* If they have some randomness during the forward pass, they should not have direct
trainable parameters. For this reason,
`Transformers
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_, which use a
dropout function (rather than a `Dropout
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layer) in a
module with some trainable parameters, have to be used with
``dropout=0.0``. Note that `Dropout
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layers are
fully supported and should be preferred. It is also perfectly fine for random modules
to have child modules that have trainable parameters, so if you have a random module with
some direct parameters, a simple fix is to wrap these parameters into a child module.

If you're building your own architecture, respecting those criteria should be quite easy.
However, if you're using an existing architecture, you may have to modify it to make it
compatible with the autogram engine. For instance, you may want to replace `BatchNorm2d
<https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ layers by
`GroupNorm <https://docs.pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html>`_ or
`InstanceNorm2d
<https://docs.pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html>`_ layers.

The alternative is to use ``batch_dim=None``, but it's not recommended since it will
increase memory usage by a lot and thus typically slow down computation.
`RNNs <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ may not be
supported on CUDA because vmap is not implemented for RNNs on that device.

.. warning::
Parent modules should call their child modules directly rather than using their child
@@ -177,14 +106,9 @@ class Engine:
another child module to avoid the slow-down.
"""

def __init__(
self,
*modules: nn.Module,
batch_dim: int | None,
):
def __init__(self, *modules: nn.Module):
self._gramian_accumulator = GramianAccumulator()
self._target_edges = EdgeRegistry()
self._batch_dim = batch_dim
self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator)
self._gramian_computers = dict[nn.Module, GramianComputer]()

@@ -193,7 +117,6 @@ def __init__(

def _hook_module_recursively(self, module: nn.Module) -> None:
if any(p.requires_grad for p in module.parameters(recurse=False)):
self._check_module_is_compatible(module)
gramian_computer = self._make_gramian_computer(module)
self._gramian_computers[module] = gramian_computer
self._module_hook_manager.hook_module(module, gramian_computer)
@@ -202,36 +125,11 @@ def _hook_module_recursively(self, module: nn.Module) -> None:
self._hook_module_recursively(child)

def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
jacobian_computer: JacobianComputer
if self._batch_dim is not None:
jacobian_computer = FunctionalJacobianComputer(module)
else:
jacobian_computer = AutogradJacobianComputer(module)
jacobian_computer = AutogradJacobianComputer(module)
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)

return gramian_computer

def _check_module_is_compatible(self, module: nn.Module) -> None:
if self._batch_dim is not None:
if isinstance(module, _MODULES_INCOMPATIBLE_WITH_BATCHED):
raise ValueError(
f"Found a module of type {type(module)}, which is incompatible with the "
f"autogram engine when `batch_dim` is not `None`. The incompatible module types"
f" are {_MODULES_INCOMPATIBLE_WITH_BATCHED} (and their subclasses). The "
f"recommended fix is to replace incompatible layers by something else (e.g. "
f"BatchNorm by InstanceNorm). If you really can't and performance is not a "
f"priority, you may also just set `batch_dim=None` when creating the engine."
)
if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats:
raise ValueError(
f"Found a module of type {type(module)}, with `track_running_stats=True`, which"
f" is incompatible with the autogram engine when `batch_dim` is not `None`, due"
f" to performing in-place operations on tensors and having side-effects during "
f"the forward pass. Try setting `track_running_stats` to `False`. If you really"
f" can't and performance is not a priority, you may also just set "
f"`batch_dim=None` when creating the engine."
)

def compute_gramian(self, output: Tensor) -> Tensor:
r"""
Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of
@@ -261,33 +159,31 @@ def compute_gramian(self, output: Tensor) -> Tensor:
- etc.
"""

if self._batch_dim is not None:
# move batched dim to the end
ordered_output = output.movedim(self._batch_dim, -1)
ordered_shape = list(ordered_output.shape)
batch_size = ordered_shape[-1]
has_non_batch_dim = len(ordered_shape) > 1
target_shape = [batch_size]
else:
ordered_output = output
ordered_shape = list(ordered_output.shape)
has_non_batch_dim = len(ordered_shape) > 0
target_shape = []
self._module_hook_manager.gramian_accumulation_phase.value = True

try:
leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))

def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
return torch.autograd.grad(
outputs=output,
inputs=leaf_targets,
grad_outputs=_grad_output,
retain_graph=True,
)

if has_non_batch_dim:
target_shape = [-1] + target_shape
output_dims = list(range(output.ndim))
jac_output = _make_initial_jac_output(output)

reshaped_output = ordered_output.reshape(target_shape)
# There are four different cases for the shape of reshaped_output:
# - Not batched and not non-batched: scalar of shape []
# - Batched only: vector of shape [batch_size]
# - Non-batched only: vector of shape [dim]
# - Batched and non-batched: matrix of shape [dim, batch_size]
vmapped_diff = differentiation
for _ in output_dims:
vmapped_diff = vmap(vmapped_diff)

self._module_hook_manager.gramian_accumulation_phase.value = True
_ = vmapped_diff(jac_output)

try:
square_gramian = self._compute_square_gramian(reshaped_output, has_non_batch_dim)
# If the gramian were None, then leaf_targets would be empty, so autograd.grad would
# have failed. So gramian is necessarily a valid Tensor here.
gramian = cast(Tensor, self._gramian_accumulator.gramian)
finally:
# Reset everything that has a state, even if the previous call raised an exception
self._module_hook_manager.gramian_accumulation_phase.value = False
@@ -296,40 +192,16 @@ def compute_gramian(self, output: Tensor) -> Tensor:
for gramian_computer in self._gramian_computers.values():
gramian_computer.reset()

unordered_gramian = reshape_gramian(square_gramian, ordered_shape)

if self._batch_dim is not None:
gramian = movedim_gramian(unordered_gramian, [-1], [self._batch_dim])
else:
gramian = unordered_gramian

return gramian

def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> Tensor:
leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))

def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
return torch.autograd.grad(
outputs=output,
inputs=leaf_targets,
grad_outputs=_grad_output,
retain_graph=True,
)

if has_non_batch_dim:
# There is one non-batched dimension, it is the first one
non_batch_dim_len = output.shape[0]
identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype)
ones = torch.ones_like(output[0])
jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones)

_ = vmap(differentiation)(jac_output)
else:
grad_output = torch.ones_like(output)
_ = differentiation(grad_output)

# If the gramian were None, then leaf_targets would be empty, so autograd.grad would
# have failed. So gramian is necessarily a valid Tensor here.
gramian = cast(Tensor, self._gramian_accumulator.gramian)
def _make_initial_jac_output(output: Tensor) -> Tensor:
if output.ndim == 0:
return torch.ones_like(output)
p_index_ranges = [torch.arange(s, device=output.device) for s in output.shape]
p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij")
v_indices_grid = p_indices_grid + p_indices_grid

return gramian
res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype)
res[v_indices_grid] = 1.0
return res
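To see what the new ``_make_initial_jac_output`` helper builds, here is a small self-contained check (the helper is restated for illustration and mirrors the private function above):

import torch

# Builds a "generalized identity": ones exactly where the first half of the
# indices equals the second half, zeros elsewhere.
def make_initial_jac_output(output: torch.Tensor) -> torch.Tensor:
    if output.ndim == 0:
        return torch.ones_like(output)
    ranges = [torch.arange(s, device=output.device) for s in output.shape]
    grid = torch.meshgrid(*ranges, indexing="ij")
    res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype)
    res[grid + grid] = 1.0
    return res

jac_output = make_initial_jac_output(torch.zeros(2, 3))  # shape: [2, 3, 2, 3]
assert torch.equal(jac_output.reshape(6, 6), torch.eye(6))

# compute_gramian vmaps `differentiation` once per output dimension, so for a
# [2, 3] output the autograd.grad call is effectively evaluated for each of the
# 6 rows of this identity: the full Jacobian is backpropagated in one batched pass.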
24 changes: 11 additions & 13 deletions src/torchjd/autogram/_gramian_computer.py
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor
from torch.utils._pytree import PyTree

from torchjd.autogram._jacobian_computer import JacobianComputer

@@ -13,8 +13,6 @@ def __call__(
self,
rg_outputs: tuple[Tensor, ...],
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> Optional[Tensor]:
"""Compute what we can for a module and optionally return the gramian if it's ready."""

@@ -30,8 +28,12 @@ def __init__(self, jacobian_computer):
self.jacobian_computer = jacobian_computer

@staticmethod
def _to_gramian(jacobian: Tensor) -> Tensor:
return jacobian @ jacobian.T
def _to_gramian(matrix: Tensor) -> Tensor:
"""Contracts the last dimension of matrix to make it into a Gramian."""

indices = list(range(matrix.ndim))
transposed_matrix = matrix.movedim(indices, indices[::-1])
return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0]))
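A quick self-contained check of the generalized contraction above, showing that it reduces to the previous ``jacobian @ jacobian.T`` behaviour in the matrix case (shapes are arbitrary):

import torch

J = torch.randn(4, 7)                   # e.g. 4 outputs, 7 parameters
indices = list(range(J.ndim))
Jt = J.movedim(indices, indices[::-1])  # a plain transpose for a 2-D tensor
G = torch.tensordot(J, Jt, dims=([-1], [0]))
assert torch.allclose(G, J @ J.T)       # same result as the old implementation

# For a higher-dimensional input, the same expression contracts only the last
# (parameter) dimension, producing a generalized Gramian of shape
# matrix.shape[:-1] + reversed(matrix.shape[:-1]).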


class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
@@ -53,20 +55,16 @@ def track_forward_call(self) -> None:
self.remaining_counter += 1

def __call__(
self,
rg_outputs: tuple[Tensor, ...],
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]
) -> Optional[Tensor]:
"""Compute what we can for a module and optionally return the gramian if it's ready."""

jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)
jacobian = self.jacobian_computer(rg_outputs, grad_outputs)

if self.summed_jacobian is None:
self.summed_jacobian = jacobian_matrix
self.summed_jacobian = jacobian
else:
self.summed_jacobian += jacobian_matrix
self.summed_jacobian += jacobian

self.remaining_counter -= 1
