Commit e44f434

refactor(autogram): Merge Vmapped to FunctionalVJP (#434)
1 parent 170e7c5 · commit e44f434

2 files changed: +17 −43 lines changed

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 2 additions & 6 deletions
@@ -9,7 +9,7 @@
 
 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
-from ._vjp import VJP, AutogradVJP, FunctionalVJP, Vmapped
+from ._vjp import VJP, AutogradVJP, FunctionalVJP
 
 # Note about import from protected _pytree module:
 # PyTorch maintainers plan to make pytree public (see
@@ -232,11 +232,7 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
         index = cast(int, preference.argmin().item())
         self.target_edges.register(get_gradient_edge(flat_outputs[index]))
 
-        vjp: VJP
-        if self.has_batch_dim:
-            vjp = Vmapped(FunctionalVJP(module))
-        else:
-            vjp = AutogradVJP(module, flat_outputs)
+        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)
 
         autograd_fn_outputs = JacobianAccumulator.apply(
             self.gramian_accumulation_phase,
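
The call site now instantiates FunctionalVJP directly instead of wrapping it in Vmapped, because the vmapping moved inside FunctionalVJP itself (see the second file below). A minimal sketch of the two designs, using hypothetical classes rather than torchjd's real VJP hierarchy:

import torch


class WrappedStyle:
    """Old design: vmapping is added by an external wrapper around a per-instance callable."""

    def __init__(self, per_instance_fn):
        self.vmapped_fn = torch.vmap(per_instance_fn)

    def __call__(self, grad_outputs, inputs):
        return self.vmapped_fn(grad_outputs, inputs)


class MergedStyle:
    """New design: the class vmaps its own per-instance method in __init__."""

    def __init__(self):
        self.vmapped_fn = torch.vmap(self._call_on_one_instance)

    def __call__(self, grad_outputs, inputs):
        return self.vmapped_fn(grad_outputs, inputs)

    def _call_on_one_instance(self, grad_output_j, input_j):
        # Toy per-instance rule standing in for a real VJP computation.
        return grad_output_j * input_j


inputs = torch.randn(8, 3)        # batch of 8 instances
cotangents = torch.randn(8, 3)    # one cotangent per instance
merged = MergedStyle()
wrapped = WrappedStyle(merged._call_on_one_instance)
assert torch.allclose(merged(cotangents, inputs), wrapped(cotangents, inputs))

Both produce the same batched result; merging simply removes one level of indirection and the extra Vmapped class.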

src/torchjd/autogram/_vjp.py

Lines changed: 15 additions & 37 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Sequence
+from collections.abc import Sequence
 
 import torch
 from torch import Tensor, nn
@@ -45,31 +45,21 @@ def __init__(self, module: nn.Module):
                 self.frozen_params[name] = param
 
 
-class Vmapped(VJP):
-    """VJP wrapper that applies the wrapped VJP, vmapped on the first dimension."""
-
-    def __init__(self, vjp: VJP):
-        super().__init__()
-        self.vmapped_vjp = torch.vmap(vjp)
-
-    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
-        return self.vmapped_vjp(grad_outputs, inputs)
-
-
 class FunctionalVJP(ModuleVJP):
     """
     Represents a VJP function for a module's forward pass with respect to its parameters using the
-    func api. The __call__ function takes both the inputs and the cotangents that can be vmapped
-    jointly in both terms to avoid providing to block diagonal jacobians. The disadvantage of using
-    this method is that it makes an extra forward pass.
-
-    :params module: The module to differentiate.
+    functional differentiation API. This requires to use vmap, so it's not compatible with
+    every module, and it requires to have an extra forward pass to create the vjp function.
     """
 
     def __init__(self, module: nn.Module):
         super().__init__(module)
+        self.vmapped_vjp = torch.vmap(self._call_on_one_instance)
+
+    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
+        return self.vmapped_vjp(grad_outputs, inputs)
 
-    def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
+    def _call_on_one_instance(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
         # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
         # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
         # nn.Flatten) do not work equivalently if they're provided with a batch or with
@@ -78,32 +68,20 @@ def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor
         inputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), inputs_j)
         grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j)
 
-        # _vjp_from_module returns a function that computes the vjp w.r.t. to the
-        # primals (tuple), here the functional has a single primal which is
-        # dict(module.named_parameters()). We therefore take the 0'th element to obtain
-        # the dict of gradients w.r.t. the module's named_parameters.
-        return self._vjp_from_module(inputs_j)(grad_outputs_j)[0]
-
-    def _vjp_from_module(self, inputs: PyTree) -> Callable[[PyTree], tuple[dict[str, Tensor]]]:
-        """
-        Create a VJP function for a module's forward pass with respect to its parameters.
-
-        Returns a function that computes vector-Jacobian products for the module's parameters given
-        fixed inputs. Only parameters with requires_grad=True are included in the differentiation.
-
-        :param inputs: Fixed inputs to the module for the VJP computation.
-        :returns: VJP function that takes cotangents and returns parameter gradients.
-        """
-
         def functional_model_call(primals: dict[str, Parameter]) -> Tensor:
             all_state = {
                 **primals,
                 **dict(self.module.named_buffers()),
                 **self.frozen_params,
             }
-            return torch.func.functional_call(self.module, all_state, inputs)
+            return torch.func.functional_call(self.module, all_state, inputs_j)
+
+        vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]
 
-        return torch.func.vjp(functional_model_call, self.trainable_params)[1]
+        # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the
+        # functional has a single primal which is dict(module.named_parameters()). We therefore take
+        # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
+        return vjp_func(grad_outputs_j)[0]
 
 
 class AutogradVJP(ModuleVJP):
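
For context, here is a minimal, self-contained sketch of the technique that FunctionalVJP now implements in a single class: build a per-instance VJP with respect to the parameters using torch.func.functional_call and torch.func.vjp, then vmap it over the batch dimension to obtain per-sample parameter gradients without forming a block-diagonal Jacobian. The helper name per_sample_param_vjp and the toy nn.Linear module are illustrative, not part of torchjd.

import torch
from torch import Tensor, nn


def per_sample_param_vjp(module: nn.Module, grad_output_j: Tensor, input_j: Tensor) -> dict[str, Tensor]:
    # Treat the single instance as a batch of size 1, since some layers (e.g. nn.Flatten)
    # behave differently on unbatched inputs.
    input_j = input_j.unsqueeze(0)
    grad_output_j = grad_output_j.unsqueeze(0)

    params = dict(module.named_parameters())
    buffers = dict(module.named_buffers())

    def functional_model_call(primals: dict[str, Tensor]) -> Tensor:
        # Run the module's forward pass with the parameters supplied as explicit primals.
        return torch.func.functional_call(module, {**primals, **buffers}, input_j)

    # torch.func.vjp returns (outputs, vjp_fn); vjp_fn maps a cotangent to a tuple of
    # gradients w.r.t. the primals. There is a single primal (the parameter dict), so
    # element 0 is the dict of per-parameter gradients.
    vjp_fn = torch.func.vjp(functional_model_call, params)[1]
    return vjp_fn(grad_output_j)[0]


module = nn.Linear(4, 2)
inputs = torch.randn(8, 4)            # batch of 8 instances
cotangents = torch.randn(8, 2)        # one cotangent per instance
vmapped = torch.vmap(lambda g, x: per_sample_param_vjp(module, g, x))
per_sample_grads = vmapped(cotangents, inputs)
print({name: g.shape for name, g in per_sample_grads.items()})
# e.g. weight: (8, 2, 4), bias: (8, 2)

This also explains the extra forward pass mentioned in the class docstring: torch.func.vjp re-runs the module's forward under functional_call in order to build the VJP function.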
