
Commit 5595f5d

refactor(autogram): Use GramianComputers working on modules (#453)
* Add GramianComputer, JacobianBasedGramianComputer and JacobianBasedGramianComputerWithCrossTerms
* Rename VJP classes to JacobianComputer classes
* Move ComputeModuleJacobians to JacobianComputer
* Make GramianAccumulator only accumulate gramians
* Use GramianComputer instead of JacobianComputer and create them outside of hook
* Uniformize some parameter ordering and types
* Remove test_gramian_accumulator
* Make InterModuleParamReuse xfail
1 parent 28cce93 commit 5595f5d

File tree

8 files changed: +323 -372 lines changed
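
To make the new division of labour concrete, here is a minimal pure-torch sketch (illustrative only, not the library's actual call path): a GramianComputer turns a per-module Jacobian matrix into a Gramian, and the GramianAccumulator now only sums ready-made Gramians. Names and shapes are hypothetical.

    import torch

    # Hypothetical per-module Jacobian matrices, shape (num_outputs, num_params_of_module).
    per_module_jacobians = [torch.randn(5, 3), torch.randn(5, 8)]

    accumulated = None
    for jacobian in per_module_jacobians:
        gramian = jacobian @ jacobian.T  # what a JacobianBasedGramianComputer returns
        # GramianAccumulator.accumulate_gramian now only has to sum these:
        accumulated = gramian if accumulated is None else accumulated + gramian

    print(accumulated.shape)  # torch.Size([5, 5])
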

src/torchjd/autogram/_engine.py

Lines changed: 23 additions & 4 deletions

@@ -6,7 +6,13 @@
 
 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
+from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
 from ._gramian_utils import movedim_gramian, reshape_gramian
+from ._jacobian_computer import (
+    AutogradJacobianComputer,
+    FunctionalJacobianComputer,
+    JacobianComputer,
+)
 from ._module_hook_manager import ModuleHookManager
 
 _MODULES_INCOMPATIBLE_WITH_BATCHED = (

@@ -179,21 +185,32 @@ def __init__(
         self._gramian_accumulator = GramianAccumulator()
         self._target_edges = EdgeRegistry()
         self._batch_dim = batch_dim
-        self._module_hook_manager = ModuleHookManager(
-            self._target_edges, self._gramian_accumulator, batch_dim is not None
-        )
+        self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator)
+        self._gramian_computers = dict[nn.Module, GramianComputer]()
 
         for module in modules:
             self._hook_module_recursively(module)
 
     def _hook_module_recursively(self, module: nn.Module) -> None:
         if any(p.requires_grad for p in module.parameters(recurse=False)):
             self._check_module_is_compatible(module)
-            self._module_hook_manager.hook_module(module)
+            gramian_computer = self._make_gramian_computer(module)
+            self._gramian_computers[module] = gramian_computer
+            self._module_hook_manager.hook_module(module, gramian_computer)
         else:
             for child in module.children():
                 self._hook_module_recursively(child)
 
+    def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
+        jacobian_computer: JacobianComputer
+        if self._batch_dim is not None:
+            jacobian_computer = FunctionalJacobianComputer(module)
+        else:
+            jacobian_computer = AutogradJacobianComputer(module)
+        gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
+
+        return gramian_computer
+
     def _check_module_is_compatible(self, module: nn.Module) -> None:
         if self._batch_dim is not None:
             if isinstance(module, _MODULES_INCOMPATIBLE_WITH_BATCHED):

@@ -276,6 +293,8 @@ def compute_gramian(self, output: Tensor) -> Tensor:
         self._module_hook_manager.gramian_accumulation_phase.value = False
         self._gramian_accumulator.reset()
         self._target_edges.reset()
+        for gramian_computer in self._gramian_computers.values():
+            gramian_computer.reset()
 
         unordered_gramian = reshape_gramian(square_gramian, ordered_shape)

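A small standalone sketch of the hooking rule used by _hook_module_recursively above: a module is hooked as soon as it directly owns a trainable parameter (recurse=False); otherwise the engine descends into its children. The helper name below is hypothetical.

    from torch import nn

    def modules_to_hook(module: nn.Module) -> list[nn.Module]:
        # Mirrors Engine._hook_module_recursively: stop at the first module that
        # directly owns a requires_grad parameter, otherwise recurse into children.
        if any(p.requires_grad for p in module.parameters(recurse=False)):
            return [module]
        hooked: list[nn.Module] = []
        for child in module.children():
            hooked.extend(modules_to_hook(child))
        return hooked

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    print([type(m).__name__ for m in modules_to_hook(model)])  # ['Linear', 'Linear'] - ReLU owns no parameters
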
src/torchjd/autogram/_gramian_accumulator.py

Lines changed: 3 additions & 51 deletions

@@ -1,8 +1,5 @@
-from collections import Counter
-from collections.abc import Iterable
 from typing import Optional
 
-import torch
 from torch import Tensor
 
 

@@ -17,60 +14,15 @@ class GramianAccumulator:
 
     def __init__(self) -> None:
         self._gramian: Optional[Tensor] = None
-        self._summed_jacobians = dict[Tensor, Tensor]()
-        self._path_counter = Counter[Tensor]()
 
     def reset(self) -> None:
         self._gramian = None
-        self._summed_jacobians = {}
-        self._path_counter = Counter()
 
-    def track_parameter_paths(self, parameters: Iterable[Tensor]) -> None:
-        """
-        Register parameters and count their paths in the computational graph.
-
-        :param parameters: Parameter tensors to track. Duplicates increase path count.
-        """
-        self._path_counter.update(parameters)
-
-    def accumulate_path_jacobians(self, path_jacobians: dict[Tensor, Tensor]) -> None:
-        """
-        Add path Jacobians for multiple parameters.
-
-        :param path_jacobians: Dictionary mapping parameters to Jacobian tensors of a single path.
-        """
-        for parameter, jacobian in path_jacobians.items():
-            self._accumulate_path_jacobian(parameter, jacobian)
-
-    def _accumulate_path_jacobian(self, parameter: Tensor, jacobian: Tensor) -> None:
-        """
-        Add path Jacobian for a parameter. In case the full Jacobian is computed, accumulate its
-        Gramian.
-
-        :param parameter: The parameter.
-        :param jacobian: path Jacobian with respect to the parameter.
-        """
-        if parameter in self._summed_jacobians:
-            self._summed_jacobians[parameter] += jacobian
-        else:
-            self._summed_jacobians[parameter] = jacobian
-        self._path_counter.subtract([parameter])
-        if self._path_counter[parameter] == 0:
-            self._accumulate_gramian(parameter)
-            del self._path_counter[parameter]
-            del self._summed_jacobians[parameter]
-
-    def _accumulate_gramian(self, parameter: Tensor) -> None:
-        """
-        Compute the Gramian of the full Jacobian and accumulate it.
-
-        :param parameter: Parameter whose full Jacobian is available.
-        """
-        full_jacobian_matrix = torch.flatten(self._summed_jacobians[parameter], start_dim=1)
+    def accumulate_gramian(self, gramian: Tensor) -> None:
         if self._gramian is not None:
-            self._gramian.addmm_(full_jacobian_matrix, full_jacobian_matrix.T)
+            self._gramian.add_(gramian)
         else:
-            self._gramian = torch.mm(full_jacobian_matrix, full_jacobian_matrix.T)
+            self._gramian = gramian
 
     @property
     def gramian(self) -> Optional[Tensor]:
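
This simplification is sound because, when each module owns a disjoint block of parameters, the Gramian of the full Jacobian equals the sum of the per-module Gramians. A quick numerical check with toy shapes (pure torch, not the library's API):

    import torch

    j_full = torch.randn(5, 7)               # full Jacobian: (num_outputs, num_params)
    j_a, j_b = j_full[:, :3], j_full[:, 3:]  # module A owns 3 params, module B owns 4

    sum_of_block_gramians = j_a @ j_a.T + j_b @ j_b.T  # what accumulate_gramian receives, module by module
    full_gramian = j_full @ j_full.T

    print(torch.allclose(sum_of_block_gramians, full_gramian))  # True
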
src/torchjd/autogram/_gramian_computer.py

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from torch import Tensor
+from torch.utils._pytree import PyTree
+
+from torchjd.autogram._jacobian_computer import JacobianComputer
+
+
+class GramianComputer(ABC):
+    @abstractmethod
+    def __call__(
+        self,
+        rg_outputs: tuple[Tensor, ...],
+        grad_outputs: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
+    ) -> Optional[Tensor]:
+        """Compute what we can for a module and optionally return the gramian if it's ready."""
+
+    def track_forward_call(self) -> None:
+        """Track that the module's forward was called. Necessary in some implementations."""
+
+    def reset(self):
+        """Reset state if any. Necessary in some implementations."""
+
+
+class JacobianBasedGramianComputer(GramianComputer, ABC):
+    def __init__(self, jacobian_computer):
+        self.jacobian_computer = jacobian_computer
+
+    @staticmethod
+    def _to_gramian(jacobian: Tensor) -> Tensor:
+        return jacobian @ jacobian.T
+
+
+class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
+    """
+    Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning
+    the gramian.
+    """
+
+    def __init__(self, jacobian_computer: JacobianComputer):
+        super().__init__(jacobian_computer)
+        self.remaining_counter = 0
+        self.summed_jacobian: Optional[Tensor] = None
+
+    def reset(self) -> None:
+        self.remaining_counter = 0
+        self.summed_jacobian = None
+
+    def track_forward_call(self) -> None:
+        self.remaining_counter += 1
+
+    def __call__(
+        self,
+        rg_outputs: tuple[Tensor, ...],
+        grad_outputs: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
+    ) -> Optional[Tensor]:
+        """Compute what we can for a module and optionally return the gramian if it's ready."""
+
+        jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)
+
+        if self.summed_jacobian is None:
+            self.summed_jacobian = jacobian_matrix
+        else:
+            self.summed_jacobian += jacobian_matrix
+
+        self.remaining_counter -= 1
+
+        if self.remaining_counter == 0:
+            gramian = self._to_gramian(self.summed_jacobian)
+            del self.summed_jacobian
+            return gramian
+        else:
+            return None
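
Why the counter matters: when a parameter is reached through several paths (e.g. a module called twice in one forward pass), the true Jacobian is the sum of the path Jacobians, and its Gramian contains cross terms that summing per-path Gramians would miss. A toy check in plain torch:

    import torch

    j1 = torch.randn(4, 6)  # Jacobian contribution of path 1
    j2 = torch.randn(4, 6)  # Jacobian contribution of path 2 (same shared parameters)

    j_total = j1 + j2
    with_cross_terms = j_total @ j_total.T       # returned once remaining_counter hits 0
    without_cross_terms = j1 @ j1.T + j2 @ j2.T  # misses j1 @ j2.T + j2 @ j1.T

    print(torch.allclose(with_cross_terms, without_cross_terms))  # False in general
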
src/torchjd/autogram/_jacobian_computer.py

Lines changed: 187 additions & 0 deletions

@@ -0,0 +1,187 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import cast
+
+import torch
+from torch import Tensor, nn
+from torch.nn import Parameter
+from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
+
+# Note about import from protected _pytree module:
+# PyTorch maintainers plan to make pytree public (see
+# https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400).
+# It should also come with better speed, because the current implementation is slow, according to
+# https://github.com/pytorch/pytorch/issues/65761#issue-1010116111.
+# When pytree becomes public, this import will have to be changed with a conditional import (to
+# still support older versions of PyTorch where pytree is protected).
+
+
+class JacobianComputer(ABC):
+    """
+    Abstract class that computes Jacobians of a module's forward pass with respect to its parameters.
+
+    :param module: The module to differentiate.
+    """
+
+    def __init__(self, module: nn.Module):
+        self.module = module
+
+        self.rg_params = dict[str, Parameter]()
+        self.frozen_params = dict[str, Parameter]()
+
+        for name, param in module.named_parameters(recurse=True):
+            if param.requires_grad:
+                self.rg_params[name] = param
+            else:
+                self.frozen_params[name] = param
+
+    def __call__(
+        self,
+        rg_outputs: tuple[Tensor, ...],
+        grad_outputs: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
+    ) -> Tensor:
+        # This makes __call__ vmappable.
+        return ComputeModuleJacobians.apply(
+            self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs
+        )
+
+    @abstractmethod
+    def _compute_jacobian(
+        self,
+        rg_outputs: tuple[Tensor, ...],
+        grad_outputs: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
+    ) -> Tensor:
+        """
+        Computes and returns the Jacobian. The output must be a matrix (2D Tensor).
+        """
+
+
+class FunctionalJacobianComputer(JacobianComputer):
+    """
+    JacobianComputer using the functional differentiation API. This requires using vmap, so it is
+    not compatible with every module, and it needs an extra forward pass to create the vjp
+    function.
+    """
+
+    def _compute_jacobian(
+        self,
+        _: tuple[Tensor, ...],
+        grad_outputs: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
+    ) -> Tensor:
+        grad_outputs_in_dims = (0,) * len(grad_outputs)
+        args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
+        kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
+        in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims)
+        vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)
+
+        return vmapped_vjp(grad_outputs, args, kwargs)
+
+    def _call_on_one_instance(
+        self,
+        grad_outputs_j: tuple[Tensor, ...],
+        args_j: tuple[PyTree, ...],
+        kwargs_j: dict[str, PyTree],
+    ) -> Tensor:
+        # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
+        # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
+        # nn.Flatten) do not work equivalently if they're provided with a batch or with
+        # an element of a batch. We thus always provide them with batches, just of a
+        # different size.
+        args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
+        kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j)
+        grad_outputs_j_ = tuple(x.unsqueeze(0) for x in grad_outputs_j)
+
+        def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]:
+            all_state = [
+                cast(dict[str, Tensor], rg_params),
+                dict(self.module.named_buffers()),
+                cast(dict[str, Tensor], self.frozen_params),
+            ]
+            output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
+            flat_outputs = tree_flatten(output)[0]
+            rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad)
+            return rg_outputs
+
+        vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]
+
+        # vjp_func is a function that computes the vjp w.r.t. the primals (tuple). Here the
+        # functional has a single primal which is dict(module.named_parameters()). We therefore take
+        # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
+        gradients = vjp_func(grad_outputs_j_)[0]
+        gradient = torch.cat([t.reshape(-1) for t in gradients.values()])
+        return gradient
+
+
+class AutogradJacobianComputer(JacobianComputer):
+    """
+    JacobianComputer using the autograd engine. The main advantage of using this method is that it
+    doesn't require making an extra forward pass.
+    """
+
+    def _compute_jacobian(
+        self,
+        rg_outputs: tuple[Tensor, ...],
+        grad_outputs: tuple[Tensor, ...],
+        _: tuple[PyTree, ...],
+        __: dict[str, PyTree],
+    ) -> Tensor:
+        flat_rg_params, ___ = tree_flatten(self.rg_params)
+        grads = torch.autograd.grad(
+            rg_outputs,
+            flat_rg_params,
+            grad_outputs,
+            retain_graph=True,
+            allow_unused=True,
+            materialize_grads=True,
+        )
+        flattened_grads = torch.cat([g.reshape(-1) for g in grads])
+        jacobian = flattened_grads.unsqueeze(0)
+        return jacobian
+
+
+class ComputeModuleJacobians(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        compute_jacobian_fn: Callable[
+            [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Tensor
+        ],
+        rg_outputs: tuple[Tensor, ...],
+        grad_outputs: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
+    ) -> Tensor:
+        # There is no non-batched dimension
+        jacobian = compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs)
+        return jacobian
+
+    @staticmethod
+    def vmap(
+        _,
+        in_dims: tuple[None, None, tuple[int, ...], None, None],
+        compute_jacobian_fn: Callable,
+        rg_outputs: tuple[Tensor, ...],
+        jac_outputs: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
+    ) -> tuple[Tensor, None]:
+        # There is a non-batched dimension
+        # We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension
+        generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])(
+            rg_outputs,
+            jac_outputs,
+            args,
+            kwargs,
+        )
+        shape = generalized_jacobian.shape
+        jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1])
+        return jacobian, None
+
+    @staticmethod
+    def setup_context(*_) -> None:
+        pass
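
For reference, a self-contained sketch of the technique FunctionalJacobianComputer relies on (torch.func.functional_call plus vjp under vmap), applied here to a plain nn.Linear. Names and shapes are illustrative, not the module hook's real inputs.

    import torch
    from torch import nn
    from torch.func import functional_call, vjp, vmap

    model = nn.Linear(3, 2)
    params = dict(model.named_parameters())

    x = torch.randn(4, 3)         # a batch of 4 inputs
    grad_out = torch.randn(4, 2)  # one grad_output (cotangent) per sample

    def one_jacobian_row_block(x_i: torch.Tensor, g_i: torch.Tensor) -> torch.Tensor:
        # Treat the single sample as a batch of size 1, as the committed code does.
        def call(p: dict[str, torch.Tensor]) -> torch.Tensor:
            return functional_call(model, p, (x_i.unsqueeze(0),))

        vjp_fn = vjp(call, params)[1]
        grads = vjp_fn(g_i.unsqueeze(0))[0]  # dict: parameter name -> gradient
        return torch.cat([g.reshape(-1) for g in grads.values()])

    jacobian = vmap(one_jacobian_row_block)(x, grad_out)  # (4, num_params)
    gramian = jacobian @ jacobian.T                       # (4, 4), ready for accumulation
    print(gramian.shape)
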
