diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index 4c1c7a4c8..ee59dbe73 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -31,7 +31,7 @@ The following example shows how to do that. optimizer = SGD(params, lr=0.1) mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module, batch_dim=0) + engine = Engine(shared_module) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index a326f582e..75841eb0c 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model, batch_dim=0) + engine = Engine(model) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] diff --git a/docs/source/examples/partial_jd.rst b/docs/source/examples/partial_jd.rst index c86a653aa..091559161 100644 --- a/docs/source/examples/partial_jd.rst +++ b/docs/source/examples/partial_jd.rst @@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time. # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:], batch_dim=0) + engine = Engine(model[2:]) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 227878968..361743a40 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -7,41 +7,9 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms -from ._gramian_utils import movedim_gramian, reshape_gramian -from ._jacobian_computer import ( - AutogradJacobianComputer, - FunctionalJacobianComputer, - JacobianComputer, -) +from ._jacobian_computer import AutogradJacobianComputer from ._module_hook_manager import ModuleHookManager -_MODULES_INCOMPATIBLE_WITH_BATCHED = ( - nn.BatchNorm1d, - nn.BatchNorm2d, - nn.BatchNorm3d, - nn.LazyBatchNorm1d, - nn.LazyBatchNorm2d, - nn.LazyBatchNorm3d, - nn.SyncBatchNorm, - nn.RNNBase, -) - -_TRACK_RUNNING_STATS_MODULE_TYPES = ( - nn.BatchNorm1d, - nn.BatchNorm2d, - nn.BatchNorm3d, - nn.LazyBatchNorm1d, - nn.LazyBatchNorm2d, - nn.LazyBatchNorm3d, - nn.SyncBatchNorm, - nn.InstanceNorm1d, - nn.InstanceNorm2d, - nn.InstanceNorm3d, - nn.LazyInstanceNorm1d, - nn.LazyInstanceNorm2d, - nn.LazyInstanceNorm3d, -) - class Engine: """ @@ -50,7 +18,7 @@ class Engine: Multi-Objective Optimization `_ but goes even further: * It works for any computation graph (not just sequential models). - * It is optimized for batched computations (as long as ``batch_dim`` is specified). + * It is highly optimized for batched computations but also supports non-batched computations. * It supports any shape of tensor to differentiate (not just a vector of losses). For more details about this, look at :meth:`Engine.compute_gramian`. @@ -66,10 +34,6 @@ class Engine: :param modules: The modules whose parameters will contribute to the Gramian of the Jacobian. 
Several modules can be provided, but it's important that none of them is a child module of another of them. - :param batch_dim: If the modules work with batches and process each batch element independently, - then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial - memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates - the batch dimension of the output tensor, if any. .. admonition:: Example @@ -97,7 +61,7 @@ class Engine: weighting = UPGradWeighting() # Create the engine before the backward pass, and only once. - engine = Engine(model, batch_dim=0) + engine = Engine(model) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] @@ -113,48 +77,13 @@ class Engine: since the Jacobian never has to be entirely in memory, it is often much more memory-efficient, and thus typically faster, to use the Gramian-based approach. + .. warning:: For autogram to be fast and low-memory, it is very important to use only batched + modules (i.e. modules that treat each element of the batch independently). For instance, + BatchNorm is not a batched module because it computes some statistics over the batch. + .. warning:: - When providing a non-None ``batch_dim``, all provided modules must respect a few conditions: - - * They should treat the elements of the batch independently. Most common layers respect - this, but for example `BatchNorm - `_ does not (it - computes some average and standard deviation over the elements of the batch). - * Their inputs and outputs can be anything, but each input tensor and each output tensor - must be batched on its first dimension. When available (e.g. in `Transformers - `_, - `MultiheadAttention - `_, - etc.), the ``batch_first`` parameter has to be set to ``True``. Also, this makes `RNNs - `_ not supported yet - because their hidden state is batched on dimension 1 even if ``batch_first`` is ``True``. - * They should not perform in-place operations on tensors (for instance you should not use - ``track_running_stats=True`` in normalization layers). - * They should not have side effects during the forward pass (since their forward pass will - be called twice, the side effects could be different from what's expected). - * If they have some randomness during the forward pass, they should not have direct - trainable parameters. For this reason, - `Transformers - `_, which use a - dropout function (rather than a `Dropout - `_ layer) in a - module with some trainable parameters, has to be used with - ``dropout=0.0``. Note that a `Dropout - `_ layers are - entirely supported and should be preferred. It is also perfectly fine for random modules - to have child modules that have trainable parameters, so if you have a random module with - some direct parameters, a simple fix is to wrap these parameters into a child module. - - If you're building your own architecture, respecting those criteria should be quite easy. - However, if you're using an existing architecture, you may have to modify it to make it - compatible with the autogram engine. For instance, you may want to replace `BatchNorm2d - `_ layers by - `GroupNorm `_ or - `InstanceNorm2d - `_ layers. - - The alternative is to use ``batch_dim=None``, but it's not recommended since it will - increase memory usage by a lot and thus typically slow down computation. + `RNNs `_ may not be + supported on cuda because vmap is not implemented for RNN on that device. .. 
warning:: Parent modules should call their child modules directly rather than using their child @@ -177,14 +106,9 @@ class Engine: another child module to avoid the slow-down. """ - def __init__( - self, - *modules: nn.Module, - batch_dim: int | None, - ): + def __init__(self, *modules: nn.Module): self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() - self._batch_dim = batch_dim self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator) self._gramian_computers = dict[nn.Module, GramianComputer]() @@ -193,7 +117,6 @@ def __init__( def _hook_module_recursively(self, module: nn.Module) -> None: if any(p.requires_grad for p in module.parameters(recurse=False)): - self._check_module_is_compatible(module) gramian_computer = self._make_gramian_computer(module) self._gramian_computers[module] = gramian_computer self._module_hook_manager.hook_module(module, gramian_computer) @@ -202,36 +125,11 @@ def _hook_module_recursively(self, module: nn.Module) -> None: self._hook_module_recursively(child) def _make_gramian_computer(self, module: nn.Module) -> GramianComputer: - jacobian_computer: JacobianComputer - if self._batch_dim is not None: - jacobian_computer = FunctionalJacobianComputer(module) - else: - jacobian_computer = AutogradJacobianComputer(module) + jacobian_computer = AutogradJacobianComputer(module) gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer) return gramian_computer - def _check_module_is_compatible(self, module: nn.Module) -> None: - if self._batch_dim is not None: - if isinstance(module, _MODULES_INCOMPATIBLE_WITH_BATCHED): - raise ValueError( - f"Found a module of type {type(module)}, which is incompatible with the " - f"autogram engine when `batch_dim` is not `None`. The incompatible module types" - f" are {_MODULES_INCOMPATIBLE_WITH_BATCHED} (and their subclasses). The " - f"recommended fix is to replace incompatible layers by something else (e.g. " - f"BatchNorm by InstanceNorm). If you really can't and performance is not a " - f"priority, you may also just set `batch_dim=None` when creating the engine." - ) - if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats: - raise ValueError( - f"Found a module of type {type(module)}, with `track_running_stats=True`, which" - f" is incompatible with the autogram engine when `batch_dim` is not `None`, due" - f" to performing in-place operations on tensors and having side-effects during " - f"the forward pass. Try setting `track_running_stats` to `False`. If you really" - f" can't and performance is not a priority, you may also just set " - f"`batch_dim=None` when creating the engine." - ) - def compute_gramian(self, output: Tensor) -> Tensor: r""" Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of @@ -261,33 +159,31 @@ def compute_gramian(self, output: Tensor) -> Tensor: - etc. 
""" - if self._batch_dim is not None: - # move batched dim to the end - ordered_output = output.movedim(self._batch_dim, -1) - ordered_shape = list(ordered_output.shape) - batch_size = ordered_shape[-1] - has_non_batch_dim = len(ordered_shape) > 1 - target_shape = [batch_size] - else: - ordered_output = output - ordered_shape = list(ordered_output.shape) - has_non_batch_dim = len(ordered_shape) > 0 - target_shape = [] + self._module_hook_manager.gramian_accumulation_phase.value = True + + try: + leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)})) + + def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: + return torch.autograd.grad( + outputs=output, + inputs=leaf_targets, + grad_outputs=_grad_output, + retain_graph=True, + ) - if has_non_batch_dim: - target_shape = [-1] + target_shape + output_dims = list(range(output.ndim)) + jac_output = _make_initial_jac_output(output) - reshaped_output = ordered_output.reshape(target_shape) - # There are four different cases for the shape of reshaped_output: - # - Not batched and not non-batched: scalar of shape [] - # - Batched only: vector of shape [batch_size] - # - Non-batched only: vector of shape [dim] - # - Batched and non-batched: matrix of shape [dim, batch_size] + vmapped_diff = differentiation + for _ in output_dims: + vmapped_diff = vmap(vmapped_diff) - self._module_hook_manager.gramian_accumulation_phase.value = True + _ = vmapped_diff(jac_output) - try: - square_gramian = self._compute_square_gramian(reshaped_output, has_non_batch_dim) + # If the gramian were None, then leaf_targets would be empty, so autograd.grad would + # have failed. So gramian is necessarily a valid Tensor here. + gramian = cast(Tensor, self._gramian_accumulator.gramian) finally: # Reset everything that has a state, even if the previous call raised an exception self._module_hook_manager.gramian_accumulation_phase.value = False @@ -296,40 +192,16 @@ def compute_gramian(self, output: Tensor) -> Tensor: for gramian_computer in self._gramian_computers.values(): gramian_computer.reset() - unordered_gramian = reshape_gramian(square_gramian, ordered_shape) - - if self._batch_dim is not None: - gramian = movedim_gramian(unordered_gramian, [-1], [self._batch_dim]) - else: - gramian = unordered_gramian - return gramian - def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> Tensor: - leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)})) - - def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: - return torch.autograd.grad( - outputs=output, - inputs=leaf_targets, - grad_outputs=_grad_output, - retain_graph=True, - ) - - if has_non_batch_dim: - # There is one non-batched dimension, it is the first one - non_batch_dim_len = output.shape[0] - identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype) - ones = torch.ones_like(output[0]) - jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones) - - _ = vmap(differentiation)(jac_output) - else: - grad_output = torch.ones_like(output) - _ = differentiation(grad_output) - # If the gramian were None, then leaf_targets would be empty, so autograd.grad would - # have failed. So gramian is necessarily a valid Tensor here. 
- gramian = cast(Tensor, self._gramian_accumulator.gramian) +def _make_initial_jac_output(output: Tensor) -> Tensor: + if output.ndim == 0: + return torch.ones_like(output) + p_index_ranges = [torch.arange(s, device=output.device) for s in output.shape] + p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") + v_indices_grid = p_indices_grid + p_indices_grid - return gramian + res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype) + res[v_indices_grid] = 1.0 + return res diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 2bc62f218..470a69d67 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Optional +import torch from torch import Tensor -from torch.utils._pytree import PyTree from torchjd.autogram._jacobian_computer import JacobianComputer @@ -13,8 +13,6 @@ def __call__( self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], ) -> Optional[Tensor]: """Compute what we can for a module and optionally return the gramian if it's ready.""" @@ -30,8 +28,12 @@ def __init__(self, jacobian_computer): self.jacobian_computer = jacobian_computer @staticmethod - def _to_gramian(jacobian: Tensor) -> Tensor: - return jacobian @ jacobian.T + def _to_gramian(matrix: Tensor) -> Tensor: + """Contracts the last dimension of matrix to make it into a Gramian.""" + + indices = list(range(matrix.ndim)) + transposed_matrix = matrix.movedim(indices, indices[::-1]) + return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0])) class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): @@ -53,20 +55,16 @@ def track_forward_call(self) -> None: self.remaining_counter += 1 def __call__( - self, - rg_outputs: tuple[Tensor, ...], - grad_outputs: tuple[Tensor, ...], - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], + self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...] 
) -> Optional[Tensor]: """Compute what we can for a module and optionally return the gramian if it's ready.""" - jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs) + jacobian = self.jacobian_computer(rg_outputs, grad_outputs) if self.summed_jacobian is None: - self.summed_jacobian = jacobian_matrix + self.summed_jacobian = jacobian else: - self.summed_jacobian += jacobian_matrix + self.summed_jacobian += jacobian self.remaining_counter -= 1 diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 26452f5de..21c457034 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -1,11 +1,9 @@ from abc import ABC, abstractmethod -from collections.abc import Callable -from typing import cast import torch from torch import Tensor, nn from torch.nn import Parameter -from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only +from torch.utils._pytree import tree_flatten # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -25,97 +23,27 @@ class JacobianComputer(ABC): def __init__(self, module: nn.Module): self.module = module - self.rg_params = dict[str, Parameter]() - self.frozen_params = dict[str, Parameter]() for name, param in module.named_parameters(recurse=True): if param.requires_grad: self.rg_params[name] = param - else: - self.frozen_params[name] = param - - def __call__( - self, - rg_outputs: tuple[Tensor, ...], - grad_outputs: tuple[Tensor, ...], - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - ) -> Tensor: - # This makes __call__ vmappable. - return ComputeModuleJacobians.apply( - self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs - ) - - @abstractmethod - def _compute_jacobian( - self, - rg_outputs: tuple[Tensor, ...], - grad_outputs: tuple[Tensor, ...], - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - ) -> Tensor: - """ - Computes and returns the Jacobian. The output must be a matrix (2D Tensor). - """ - - -class FunctionalJacobianComputer(JacobianComputer): - """ - JacobianComputer using the functional differentiation API. This requires to use vmap, so it's - not compatible with every module, and it requires to have an extra forward pass to create the - vjp function. - """ - - def _compute_jacobian( - self, - _: tuple[Tensor, ...], - grad_outputs: tuple[Tensor, ...], - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - ) -> Tensor: - grad_outputs_in_dims = (0,) * len(grad_outputs) - args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) - kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) - in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims) - vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) - - return vmapped_vjp(grad_outputs, args, kwargs) - def _call_on_one_instance( - self, - grad_outputs_j: tuple[Tensor, ...], - args_j: tuple[PyTree, ...], - kwargs_j: dict[str, PyTree], - ) -> Tensor: - # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a - # "batch" of 1 activation (or grad_output). This is because some layers (e.g. - # nn.Flatten) do not work equivalently if they're provided with a batch or with - # an element of a batch. We thus always provide them with batches, just of a - # different size. 
- args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) - kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j) - grad_outputs_j_ = tuple(x.unsqueeze(0) for x in grad_outputs_j) + def __call__(self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]) -> Tensor: + """Computes and returns the generalized Jacobian, with its parameter dimensions grouped""" - def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]: - all_state = [ - cast(dict[str, Tensor], rg_params), - dict(self.module.named_buffers()), - cast(dict[str, Tensor], self.frozen_params), - ] - output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) - flat_outputs = tree_flatten(output)[0] - rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad) - return rg_outputs + batched_jacobian = self.compute(rg_outputs, grad_outputs) - vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1] + # I think that in our specific case, it's safe to use debug_unwrap, because jacobian will + # escape from the vmap context anyway. We basically combine two forbidden things (escaping a + # tensor from a vmap context, and unwrapping a BatchedTensor) that seem to be ok when + # combined. + jacobian = torch.func.debug_unwrap(batched_jacobian, recurse=True) + return jacobian - # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the - # functional has a single primal which is dict(module.named_parameters()). We therefore take - # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters. - gradients = vjp_func(grad_outputs_j_)[0] - gradient = torch.cat([t.reshape(-1) for t in gradients.values()]) - return gradient + @abstractmethod + def compute(self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]) -> Tensor: + """Computes and returns the generalized Jacobian, possibly batched.""" class AutogradJacobianComputer(JacobianComputer): @@ -124,13 +52,7 @@ class AutogradJacobianComputer(JacobianComputer): doesn't require making an extra forward pass. 
""" - def _compute_jacobian( - self, - rg_outputs: tuple[Tensor, ...], - grad_outputs: tuple[Tensor, ...], - _: tuple[PyTree, ...], - __: dict[str, PyTree], - ) -> Tensor: + def compute(self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]) -> Tensor: flat_rg_params, ___ = tree_flatten(self.rg_params) grads = torch.autograd.grad( rg_outputs, @@ -141,47 +63,4 @@ def _compute_jacobian( materialize_grads=True, ) flattened_grads = torch.cat([g.reshape(-1) for g in grads]) - jacobian = flattened_grads.unsqueeze(0) - return jacobian - - -class ComputeModuleJacobians(torch.autograd.Function): - @staticmethod - def forward( - compute_jacobian_fn: Callable[ - [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Tensor - ], - rg_outputs: tuple[Tensor, ...], - grad_outputs: tuple[Tensor, ...], - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - ) -> Tensor: - # There is no non-batched dimension - jacobian = compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs) - return jacobian - - @staticmethod - def vmap( - _, - in_dims: tuple[None, None, tuple[int, ...], None, None], - compute_jacobian_fn: Callable, - rg_outputs: tuple[Tensor, ...], - jac_outputs: tuple[Tensor, ...], - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - ) -> tuple[Tensor, None]: - # There is a non-batched dimension - # We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension - generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])( - rg_outputs, - jac_outputs, - args, - kwargs, - ) - shape = generalized_jacobian.shape - jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1]) - return jacobian, None - - @staticmethod - def setup_context(*_) -> None: - pass + return flattened_grads diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 082ace69e..717af2a7c 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -63,7 +63,7 @@ def hook_module(self, module: nn.Module, gramian_computer: GramianComputer) -> N self._gramian_accumulator, gramian_computer, ) - self._handles.append(module.register_forward_hook(hook, with_kwargs=True)) + self._handles.append(module.register_forward_hook(hook)) @staticmethod def remove_hooks(handles: list[TorchRemovableHandle]) -> None: @@ -102,13 +102,9 @@ def __init__( def __call__( self, module: nn.Module, - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], + _: tuple[PyTree, ...], outputs: PyTree, ) -> PyTree: - if self.gramian_accumulation_phase: - return outputs - flat_outputs, output_spec = tree_flatten(outputs) rg_outputs = list[Tensor]() @@ -135,8 +131,6 @@ def __call__( autograd_fn_rg_outputs = AutogramNode.apply( self.gramian_accumulation_phase, self.gramian_computer, - args, - kwargs, self.gramian_accumulator, *rg_outputs, ) @@ -159,15 +153,11 @@ class AutogramNode(torch.autograd.Function): def forward( gramian_accumulation_phase: BoolRef, gramian_computer: GramianComputer, - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, *rg_tensors: Tensor, ) -> tuple[Tensor, ...]: return tuple(t.detach() for t in rg_tensors) - # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, GramianComputer, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -176,23 +166,16 @@ def setup_context( ): ctx.gramian_accumulation_phase = inputs[0] 
ctx.gramian_computer = inputs[1] - ctx.args = inputs[2] - ctx.kwargs = inputs[3] - ctx.gramian_accumulator = inputs[4] - ctx.rg_outputs = inputs[5:] + ctx.gramian_accumulator = inputs[2] + ctx.rg_outputs = inputs[3:] @staticmethod def backward(ctx, *grad_outputs: Tensor) -> tuple: - # For python > 3.10: -> tuple[None, None, None, None, None, *tuple[Tensor, ...]] + # For python > 3.10: -> tuple[None, None, None, *tuple[Tensor, ...]] if ctx.gramian_accumulation_phase: - optional_gramian = ctx.gramian_computer( - ctx.rg_outputs, - grad_outputs, - ctx.args, - ctx.kwargs, - ) + optional_gramian = ctx.gramian_computer(ctx.rg_outputs, grad_outputs) if optional_gramian is not None: ctx.gramian_accumulator.accumulate_gramian(optional_gramian) - return None, None, None, None, None, *grad_outputs + return None, None, None, *grad_outputs diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 64ce48f7e..ce92b6517 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -20,7 +20,7 @@ def test_engine(): weighting = UPGradWeighting() # Create the engine before the backward pass, and only once. - engine = Engine(model, batch_dim=0) + engine = Engine(model) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 53d92ed2f..7a381d628 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -94,7 +94,7 @@ def test_iwmtl(): optimizer = SGD(params, lr=0.1) mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module, batch_dim=0) + engine = Engine(shared_module) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task @@ -184,7 +184,7 @@ def test_autogram(): params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model, batch_dim=0) + engine = Engine(model) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] @@ -374,7 +374,7 @@ def test_partial_jd(): # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. 
- engine = Engine(model[2:], batch_dim=0) + engine = Engine(model[2:]) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 7a6556f27..49b40d506 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -116,7 +116,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model, batch_dim=0) + engine = Engine(model) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 50135796e..f1e747fe8 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -2,7 +2,6 @@ from itertools import combinations from math import prod -import pytest import torch from pytest import mark, param from torch import Tensor @@ -83,7 +82,7 @@ from torchjd.autogram._engine import Engine from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian -PARAMETRIZATIONS = [ +BASE_PARAMETRIZATIONS = [ (ModuleFactory(OverlyNested), 32), (ModuleFactory(MultiInputSingleOutput), 32), (ModuleFactory(MultiInputMultiOutput), 32), @@ -128,6 +127,9 @@ ), (ModuleFactory(FreeParam), 32), (ModuleFactory(NoFreeParam), 32), + (ModuleFactory(Randomness), 32), + (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 32), + (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 32), param(ModuleFactory(Cifar10Model), 16, marks=mark.slow), param(ModuleFactory(AlexNet), 2, marks=mark.slow), param(ModuleFactory(InstanceNormResNet18), 4, marks=mark.slow), @@ -141,12 +143,19 @@ ), ] +# These parametrizations are expected to fail on test_autograd_while_modules_are_hooked +_SPECIAL_PARAMETRIZATIONS = [ + (ModuleFactory(WithSideEffect), 32), # When use_engine=True, double side-effect + param(ModuleFactory(WithRNN), 32, marks=mark.xfail_if_cuda), # Does not fail on cuda + # when use_engine=False because engine is not even used. 
+] -def _assert_gramian_is_equivalent_to_autograd( - factory: ModuleFactory, batch_size: int, batch_dim: int | None -): +PARAMETRIZATIONS = BASE_PARAMETRIZATIONS + _SPECIAL_PARAMETRIZATIONS + + +def _assert_gramian_is_equivalent_to_autograd(factory: ModuleFactory, batch_size: int): model_autograd, model_autogram = factory(), factory() - engine = Engine(model_autogram, batch_dim=batch_dim) + engine = Engine(model_autogram) inputs, targets = make_inputs_and_targets(model_autograd, batch_size) loss_fn = make_mse_loss_fn(targets) @@ -187,34 +196,10 @@ def _get_losses_and_params_without_cross_terms( @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) -@mark.parametrize("batch_dim", [0, None]) -def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int | None): +def test_compute_gramian(factory: ModuleFactory, batch_size: int): """Tests that the autograd and the autogram engines compute the same gramian.""" - _assert_gramian_is_equivalent_to_autograd(factory, batch_size, batch_dim) - - -@mark.parametrize( - "factory", - [ - ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), - ModuleFactory(WithSideEffect), - ModuleFactory(Randomness), - ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), - param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda), - ], -) -@mark.parametrize("batch_size", [1, 3, 32]) -@mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None]) -def test_compute_gramian_with_weird_modules( - factory: ModuleFactory, batch_size: int, batch_dim: int | None -): - """ - Tests that compute_gramian works even with some problematic modules when batch_dim is None. It - is expected to fail on those when the engine uses the batched optimization (when batch_dim=0). - """ - - _assert_gramian_is_equivalent_to_autograd(factory, batch_size, batch_dim) + _assert_gramian_is_equivalent_to_autograd(factory, batch_size) @mark.xfail @@ -227,45 +212,35 @@ def test_compute_gramian_with_weird_modules( ], ) @mark.parametrize("batch_size", [1, 3, 32]) -@mark.parametrize("batch_dim", [0, None]) -def test_compute_gramian_unsupported_architectures( - factory: ModuleFactory, batch_size: int, batch_dim: int | None -): +def test_compute_gramian_unsupported_architectures(factory: ModuleFactory, batch_size: int): """ Tests compute_gramian on some architectures that are known to be unsupported. It is expected to fail. 
""" - _assert_gramian_is_equivalent_to_autograd(factory, batch_size, batch_dim) + _assert_gramian_is_equivalent_to_autograd(factory, batch_size) @mark.parametrize("batch_size", [1, 3, 16]) @mark.parametrize( - ["reduction", "movedim_source", "movedim_destination", "batch_dim"], + ["reduction", "movedim_source", "movedim_destination"], [ # 0D - (reduce_to_scalar, [], [], None), # () + (reduce_to_scalar, [], []), # () # 1D - (reduce_to_vector, [], [], 0), # (batch_size,) - (reduce_to_vector, [], [], None), # (batch_size,) + (reduce_to_vector, [], []), # (batch_size,) # 2D - (reduce_to_matrix, [], [], 0), # (batch_size, d1 * d2) - (reduce_to_matrix, [], [], None), # (batch_size, d1 * d2) - (reduce_to_matrix, [0], [1], 1), # (d1 * d2, batch_size) - (reduce_to_matrix, [0], [1], None), # (d1 * d2, batch_size) + (reduce_to_matrix, [], []), # (batch_size, d1 * d2) + (reduce_to_matrix, [0], [1]), # (d1 * d2, batch_size) # 3D - (reduce_to_first_tensor, [], [], 0), # (batch_size, d1, d2) - (reduce_to_first_tensor, [], [], None), # (batch_size, d1, d2) - (reduce_to_first_tensor, [0], [1], 1), # (d1, batch_size, d2) - (reduce_to_first_tensor, [0], [1], None), # (d1, batch_size, d2) - (reduce_to_first_tensor, [0], [2], 2), # (d2, d1, batch_size) - (reduce_to_first_tensor, [0], [2], None), # (d2, d1, batch_size) + (reduce_to_first_tensor, [], []), # (batch_size, d1, d2) + (reduce_to_first_tensor, [0], [1]), # (d1, batch_size, d2) + (reduce_to_first_tensor, [0], [2]), # (d2, d1, batch_size) ], ) def test_compute_gramian_various_output_shapes( batch_size: int | None, reduction: Callable[[list[Tensor]], Tensor], - batch_dim: int | None, movedim_source: list[int], movedim_destination: list[int], ): @@ -286,7 +261,7 @@ def test_compute_gramian_various_output_shapes( autograd_gramian = compute_gramian_with_autograd(loss_vector, params) expected_gramian = reshape_gramian(autograd_gramian, list(reshaped_losses.shape)) - engine = Engine(model_autogram, batch_dim=batch_dim) + engine = Engine(model_autogram) losses = forward_pass(model_autogram, inputs, loss_fn, reduction) reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination) autogram_gramian = engine.compute_gramian(reshaped_losses) @@ -302,8 +277,7 @@ def _non_empty_subsets(elements: set) -> list[set]: @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"})) -@mark.parametrize("batch_dim", [0, None]) -def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int | None): +def test_compute_partial_gramian(gramian_module_names: set[str]): """ Tests that the autograd and the autogram engines compute the same gramian when only a subset of the model parameters is specified. 
@@ -322,7 +296,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int losses = forward_pass(model, inputs, loss_fn, reduce_to_vector) autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) - engine = Engine(*gramian_modules, batch_dim=batch_dim) + engine = Engine(*gramian_modules) losses = forward_pass(model, inputs, loss_fn, reduce_to_vector) gramian = engine.compute_gramian(losses) @@ -330,14 +304,13 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) -@mark.parametrize("batch_dim", [0, None]) -def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch_dim: int | None): +def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int): """Tests that the autogram engine doesn't raise any error during several IWRM iterations.""" n_iter = 3 model = factory() weighting = UPGradWeighting() - engine = Engine(model, batch_dim=batch_dim) + engine = Engine(model) optimizer = SGD(model.parameters(), lr=1e-7) for i in range(n_iter): @@ -348,11 +321,12 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch model.zero_grad() -@mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) +@mark.parametrize(["factory", "batch_size"], BASE_PARAMETRIZATIONS) @mark.parametrize("use_engine", [False, True]) -@mark.parametrize("batch_dim", [0, None]) def test_autograd_while_modules_are_hooked( - factory: ModuleFactory, batch_size: int, use_engine: bool, batch_dim: int | None + factory: ModuleFactory, + batch_size: int, + use_engine: bool, ): """ Tests that the hooks added when constructing the engine do not interfere with a simple autograd @@ -367,11 +341,11 @@ def test_autograd_while_modules_are_hooked( autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} # Hook modules and optionally compute the Gramian - engine = Engine(model_autogram, batch_dim=batch_dim) + engine = Engine(model_autogram) + if use_engine: losses = forward_pass(model_autogram, inputs, loss_fn, reduce_to_vector) _ = engine.compute_gramian(losses) - # Verify that even with the hooked modules, autograd works normally when not using the engine. # Results should be the same as a normal call to autograd, and no time should be spent computing # the gramian at all. 
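For context, these tests exercise the same construction pattern as the updated documentation examples. A minimal end-to-end sketch of an IWRM step with the one-argument engine is given below; the public import paths and the ``losses.backward(weights)`` pattern are assumed from the documentation pages touched earlier in this diff, so treat it as illustrative rather than canonical::

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import UPGradWeighting  # assumed public import path
    from torchjd.autogram import Engine  # assumed public import path

    model = Sequential(Linear(10, 32), ReLU(), Linear(32, 1))
    optimizer = SGD(model.parameters(), lr=0.1)
    weighting = UPGradWeighting()
    engine = Engine(model)  # no batch_dim argument anymore

    loss_fn = MSELoss(reduction="none")
    x, y = torch.randn(16, 10), torch.randn(16)

    y_hat = model(x).squeeze(dim=1)           # shape: [16]
    losses = loss_fn(y_hat, y)                # shape: [16], one loss per sample
    optimizer.zero_grad()
    gramian = engine.compute_gramian(losses)  # shape: [16, 16]
    weights = weighting(gramian)              # shape: [16]
    losses.backward(weights)                  # weighted backward pass
    optimizer.step()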
@@ -382,22 +356,6 @@ def test_autograd_while_modules_are_hooked( assert engine._gramian_accumulator.gramian is None -@mark.parametrize( - ["factory", "batch_dim"], - [ - (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 0), - param(ModuleFactory(WithRNN), 0), - (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0), - ], -) -def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None): - """Tests that the engine cannot be constructed with incompatible modules.""" - - model = factory() - with pytest.raises(ValueError): - _ = Engine(model, batch_dim=batch_dim) - - def test_compute_gramian_manual(): """ Tests that the Gramian computed by the `Engine` equals to a manual computation of the expected @@ -410,7 +368,7 @@ def test_compute_gramian_manual(): model = factory() input = randn_(in_dims) - engine = Engine(model, batch_dim=None) + engine = Engine(model) output = model(input) gramian = engine.compute_gramian(output) @@ -455,12 +413,12 @@ def test_reshape_equivariance(shape: list[int]): model1, model2 = factory(), factory() input = randn_([input_size]) - engine1 = Engine(model1, batch_dim=None) + engine1 = Engine(model1) output = model1(input) gramian = engine1.compute_gramian(output) expected_reshaped_gramian = reshape_gramian(gramian, shape[1:]) - engine2 = Engine(model2, batch_dim=None) + engine2 = Engine(model2) reshaped_output = model2(input).reshape(shape[1:]) reshaped_gramian = engine2.compute_gramian(reshaped_output) @@ -493,78 +451,13 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: model1, model2 = factory(), factory() input = randn_([input_size]) - engine1 = Engine(model1, batch_dim=None) + engine1 = Engine(model1) output = model1(input).reshape(shape[1:]) gramian = engine1.compute_gramian(output) expected_moved_gramian = movedim_gramian(gramian, source, destination) - engine2 = Engine(model2, batch_dim=None) + engine2 = Engine(model2) moved_output = model2(input).reshape(shape[1:]).movedim(source, destination) moved_gramian = engine2.compute_gramian(moved_output) assert_close(moved_gramian, expected_moved_gramian) - - -@mark.parametrize( - ["shape", "batch_dim"], - [ - ([2, 5, 3, 2], 2), - ([3, 2, 5], 1), - ([6, 3], 0), - ([4, 3, 2], 1), - ([1, 1, 1], 0), - ([1, 1, 1], 1), - ([1, 1, 1], 2), - ([1, 1], 0), - ([1], 0), - ([4, 3, 1], 2), - ], -) -def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): - """ - Tests that for a vector with some batched dimensions, the gramian is the same if we use the - appropriate `batch_dim` or if we don't use any. 
- """ - - non_batched_shape = [shape[i] for i in range(len(shape)) if i != batch_dim] - input_size = prod(non_batched_shape) - batch_size = shape[batch_dim] - output_size = input_size - factory = ModuleFactory(Linear, input_size, output_size) - model1, model2 = factory(), factory() - input = randn_([batch_size, input_size]) - - engine1 = Engine(model1, batch_dim=batch_dim) - output1 = model1(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim) - gramian1 = engine1.compute_gramian(output1) - - engine2 = Engine(model2, batch_dim=None) - output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim) - gramian2 = engine2.compute_gramian(output2) - - assert_close(gramian1, gramian2) - - -@mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) -def test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: int): - """ - Same as test_batched_non_batched_equivalence but on real architectures, and thus only between - batch_size=0 and batch_size=None. - - If for some architecture this test passes but the test_compute_gramian doesn't pass, it could be - that the get_used_params does not work for some module of the architecture. - """ - - model_0, model_none = factory(), factory() - inputs, targets = make_inputs_and_targets(model_0, batch_size) - loss_fn = make_mse_loss_fn(targets) - - engine_0 = Engine(model_0, batch_dim=0) - losses_0 = forward_pass(model_0, inputs, loss_fn, reduce_to_vector) - gramian_0 = engine_0.compute_gramian(losses_0) - - engine_none = Engine(model_none, batch_dim=None) - losses_none = forward_pass(model_none, inputs, loss_fn, reduce_to_vector) - gramian_none = engine_none.compute_gramian(losses_none) - - assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5)