Commit 170e7c5

refactor(autogram): Rework VJP class hierarchy (#433)
* Split VJP base class into VJP (for the __call__ prototype) and ModuleVJP (for referencing a module, its trainable and frozen parameters).
* Add Vmapped to wrap a VJP.
* Remove VJPType (now useless because a vmapped VJP will also be of type VJP).
1 parent f3d0701 commit 170e7c5
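For readers skimming the diff, the sketch below condenses the class hierarchy described by the commit message. It is distilled from the _vjp.py changes further down, with simplified bodies and without the PyTree annotations; it is not the actual torchjd source.

from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn


class VJP(ABC):
    """Fixes only the __call__ prototype: (grad_outputs, inputs) -> {param name: gradient}."""

    @abstractmethod
    def __call__(self, grad_outputs, inputs) -> dict[str, Tensor]: ...


class ModuleVJP(VJP, ABC):
    """Adds the module bookkeeping: the module and its trainable and frozen parameters."""

    def __init__(self, module: nn.Module):
        self.module = module
        self.trainable_params = {n: p for n, p in module.named_parameters() if p.requires_grad}
        self.frozen_params = {n: p for n, p in module.named_parameters() if not p.requires_grad}


class Vmapped(VJP):
    """Wraps any VJP and applies it vmapped over the first (batch) dimension."""

    def __init__(self, vjp: VJP):
        self.vmapped_vjp = torch.vmap(vjp)

    def __call__(self, grad_outputs, inputs) -> dict[str, Tensor]:
        return self.vmapped_vjp(grad_outputs, inputs)


# FunctionalVJP and AutogradVJP (second file below) now subclass ModuleVJP rather than VJP,
# and the VJPType alias disappears: a Vmapped VJP is itself a VJP.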

File tree

2 files changed: +28 -16 lines changed

* src/torchjd/autogram/_module_hook_manager.py
* src/torchjd/autogram/_vjp.py
src/torchjd/autogram/_module_hook_manager.py

Lines changed: 7 additions & 6 deletions
@@ -9,7 +9,7 @@
 
 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
-from ._vjp import AutogradVJP, FunctionalVJP, VJPType
+from ._vjp import VJP, AutogradVJP, FunctionalVJP, Vmapped
 
 # Note about import from protected _pytree module:
 # PyTorch maintainers plan to make pytree public (see
@@ -93,7 +93,7 @@ class AccumulateJacobian(torch.autograd.Function):
     @staticmethod
     def forward(
         output_spec: TreeSpec,
-        vjp: VJPType,
+        vjp: VJP,
         args: PyTree,
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
@@ -110,7 +110,7 @@ def vmap(
         _,
         in_dims: PyTree,
         output_spec: TreeSpec,
-        vjp: VJPType,
+        vjp: VJP,
         args: PyTree,
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
@@ -157,7 +157,7 @@ class JacobianAccumulator(torch.autograd.Function):
     def forward(
         gramian_accumulation_phase: BoolRef,
         output_spec: TreeSpec,
-        vjp: VJPType,
+        vjp: VJP,
         args: PyTree,
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
@@ -166,7 +166,7 @@ def forward(
         return tuple(x.detach() for x in xs)
 
     # For Python version > 3.10, the type of `inputs` should become
-    # tuple[BoolRef, TreeSpec, VJPType, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
+    # tuple[BoolRef, TreeSpec, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
     @staticmethod
     def setup_context(
         ctx,
@@ -232,8 +232,9 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
         index = cast(int, preference.argmin().item())
         self.target_edges.register(get_gradient_edge(flat_outputs[index]))
 
+        vjp: VJP
         if self.has_batch_dim:
-            vjp = torch.vmap(FunctionalVJP(module))
+            vjp = Vmapped(FunctionalVJP(module))
         else:
             vjp = AutogradVJP(module, flat_outputs)
 

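The behavioral change in this file is confined to the hook's __call__: the batched branch now builds a Vmapped object instead of a bare torch.vmap closure, so both branches satisfy the same VJP annotation. Below is a hedged sketch of that dispatch pulled out of its real class context; build_vjp is a hypothetical helper, not part of the commit.

from torch import Tensor, nn

from torchjd.autogram._vjp import VJP, AutogradVJP, FunctionalVJP, Vmapped


def build_vjp(module: nn.Module, flat_outputs: list[Tensor], has_batch_dim: bool) -> VJP:
    # Mirrors the @@ -232 hunk above: a single return type now covers both branches,
    # which is what made the old VJPType alias unnecessary.
    if has_batch_dim:
        return Vmapped(FunctionalVJP(module))  # functional VJP, vmapped over the batch dim
    return AutogradVJP(module, flat_outputs)   # plain autograd-engine VJP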
src/torchjd/autogram/_vjp.py

Lines changed: 21 additions & 10 deletions
@@ -15,11 +15,18 @@
 # still support older versions of PyTorch where pytree is protected).
 
 
-# This includes vmapped VJPs, which are not of type VJP.
-VJPType = Callable[[PyTree, PyTree], dict[str, Tensor]]
+class VJP(ABC):
+    """Represents an abstract VJP function."""
 
+    @abstractmethod
+    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
+        """
+        Computes and returns the dictionary of parameter names to their gradients for the given
+        grad_outputs (cotangents) and at the given inputs.
+        """
 
-class VJP(ABC):
+
+class ModuleVJP(VJP, ABC):
     """
     Represents an abstract VJP function for a module's forward pass with respect to its parameters.
@@ -37,15 +44,19 @@ def __init__(self, module: nn.Module):
         else:
             self.frozen_params[name] = param
 
-    @abstractmethod
+
+class Vmapped(VJP):
+    """VJP wrapper that applies the wrapped VJP, vmapped on the first dimension."""
+
+    def __init__(self, vjp: VJP):
+        super().__init__()
+        self.vmapped_vjp = torch.vmap(vjp)
+
     def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
-        """
-        Computes and returns the dictionary of parameter names to their gradients for the given
-        grad_outputs (cotangents) and at the given inputs.
-        """
+        return self.vmapped_vjp(grad_outputs, inputs)
 
 
-class FunctionalVJP(VJP):
+class FunctionalVJP(ModuleVJP):
     """
     Represents a VJP function for a module's forward pass with respect to its parameters using the
     func api. The __call__ function takes both the inputs and the cotangents that can be vmapped
@@ -95,7 +106,7 @@ def functional_model_call(primals: dict[str, Parameter]) -> Tensor:
         return torch.func.vjp(functional_model_call, self.trainable_params)[1]
 
 
-class AutogradVJP(VJP):
+class AutogradVJP(ModuleVJP):
     """
     Represents a VJP function for a module's forward pass with respect to its parameters using the
     autograd engine. The __call__ function takes both the inputs and the cotangents but ignores the

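As background for the Vmapped wrapper above: torch.vmap accepts any callable, including a VJP instance, and maps it over the leading (batch) dimension of both the cotangents and the inputs. A tiny self-contained illustration with a toy per-sample VJP follows; it is not torchjd code, and the linear map y = w . x is an arbitrary example.

import torch
from torch import Tensor


def per_sample_vjp(grad_output: Tensor, x: Tensor) -> dict[str, Tensor]:
    # VJP of y = w . x with respect to w, for a single sample: grad_output * x.
    return {"w": grad_output * x}


vmapped_vjp = torch.vmap(per_sample_vjp)  # what Vmapped does to the wrapped VJP

grad_outputs = torch.ones(8)   # one scalar cotangent per sample
inputs = torch.randn(8, 3)     # batch of 8 samples
per_sample_grads = vmapped_vjp(grad_outputs, inputs)
print(per_sample_grads["w"].shape)  # torch.Size([8, 3])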