
Commit 87b4da0

Simplify engine:

* Remove FunctionalJacobianComputer.
* Remove args and kwargs from the interface of JacobianComputer, GramianComputer and JacobianAccumulator, because they were only needed for the functional interface.
* Remove kwargs from the interface of Hook and stop registering it with with_kwargs=True (args are still mandatory, so rename them to _).
* Change JacobianComputer to compute generalized Jacobians (shape [m0, ..., mk, n]) and change GramianComputer to compute optional generalized Gramians (shape [m0, ..., mk, mk, ..., m0]).
* Change engine.compute_gramian to always simply do one vmap level per dimension of the output, without caring about the batch_dim.
* Remove all reshapes and movedims in engine.compute_gramian: we don't need reshape anymore since the Gramian is directly a generalized Gramian, and we don't need movedim anymore since we vmap over all dimensions in the same way, without having to put the non-batched dim in front. Merge compute_gramian and _compute_square_gramian.
* Use a DiagonalSparseTensor as the initial jac_output of compute_gramian.
1 parent: efa8019 · commit: 87b4da0

File tree

4 files changed (+50, -242 lines)
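To make the shape conventions from the commit message concrete, here is a small standalone sketch (not code from this commit; the sizes m0 = 3, m1 = 4, n = 5 are arbitrary) of how a generalized Jacobian of shape [m0, ..., mk, n] is contracted over its parameter dimension into a generalized Gramian of shape [m0, ..., mk, mk, ..., m0], mirroring the _to_gramian change below:

import torch

# Hypothetical sizes: output dimensions m0 = 3, m1 = 4, and n = 5 flattened parameters.
m0, m1, n = 3, 4, 5
jacobian = torch.randn(m0, m1, n)  # generalized Jacobian, shape [m0, m1, n]

# Contract the parameter dimension against a fully reversed copy, as the new _to_gramian does.
indices = list(range(jacobian.ndim))
transposed = jacobian.movedim(indices, indices[::-1])  # shape [n, m1, m0]
gramian = torch.tensordot(jacobian, transposed, dims=([-1], [0]))  # shape [m0, m1, m1, m0]

# Sanity check: this equals the ordinary matrix Gramian, up to reshaping and reordering dims.
flat = jacobian.reshape(m0 * m1, n)
expected = (flat @ flat.T).reshape(m0, m1, m0, m1).permute(0, 1, 3, 2)
assert torch.allclose(gramian, expected)

The reversed ordering of the second half of the Gramian's dimensions comes from reversing all dimensions of the Jacobian before the contraction.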

src/torchjd/autogram/_engine.py

Lines changed: 24 additions & 70 deletions
@@ -7,13 +7,9 @@
 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
 from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
-from ._gramian_utils import movedim_gramian, reshape_gramian
-from ._jacobian_computer import (
-    AutogradJacobianComputer,
-    FunctionalJacobianComputer,
-    JacobianComputer,
-)
+from ._jacobian_computer import AutogradJacobianComputer
 from ._module_hook_manager import ModuleHookManager
+from .diagonal_sparse_tensor import DiagonalSparseTensor
 
 _MODULES_INCOMPATIBLE_WITH_BATCHED = (
     nn.BatchNorm1d,
@@ -202,11 +198,7 @@ def _hook_module_recursively(self, module: nn.Module) -> None:
             self._hook_module_recursively(child)
 
     def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
-        jacobian_computer: JacobianComputer
-        if self._batch_dim is not None:
-            jacobian_computer = FunctionalJacobianComputer(module)
-        else:
-            jacobian_computer = AutogradJacobianComputer(module)
+        jacobian_computer = AutogradJacobianComputer(module)
         gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
 
         return gramian_computer
@@ -261,33 +253,31 @@ def compute_gramian(self, output: Tensor) -> Tensor:
         - etc.
         """
 
-        if self._batch_dim is not None:
-            # move batched dim to the end
-            ordered_output = output.movedim(self._batch_dim, -1)
-            ordered_shape = list(ordered_output.shape)
-            batch_size = ordered_shape[-1]
-            has_non_batch_dim = len(ordered_shape) > 1
-            target_shape = [batch_size]
-        else:
-            ordered_output = output
-            ordered_shape = list(ordered_output.shape)
-            has_non_batch_dim = len(ordered_shape) > 0
-            target_shape = []
+        self._module_hook_manager.gramian_accumulation_phase.value = True
 
-        if has_non_batch_dim:
-            target_shape = [-1] + target_shape
+        try:
+            leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))
+
+            def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
+                return torch.autograd.grad(
+                    outputs=output,
+                    inputs=leaf_targets,
+                    grad_outputs=_grad_output,
+                    retain_graph=True,
+                )
 
-        reshaped_output = ordered_output.reshape(target_shape)
-        # There are four different cases for the shape of reshaped_output:
-        # - Not batched and not non-batched: scalar of shape []
-        # - Batched only: vector of shape [batch_size]
-        # - Non-batched only: vector of shape [dim]
-        # - Batched and non-batched: matrix of shape [dim, batch_size]
+            output_dims = list(range(output.ndim))
+            jac_output = DiagonalSparseTensor(torch.ones_like(output), output_dims * 2)
 
-        self._module_hook_manager.gramian_accumulation_phase.value = True
+            vmapped_diff = differentiation
+            for _ in output_dims:
+                vmapped_diff = vmap(vmapped_diff)
 
-        try:
-            square_gramian = self._compute_square_gramian(reshaped_output, has_non_batch_dim)
+            _ = vmapped_diff(jac_output)
+
+            # If the gramian were None, then leaf_targets would be empty, so autograd.grad would
+            # have failed. So gramian is necessarily a valid Tensor here.
+            gramian = cast(Tensor, self._gramian_accumulator.gramian)
         finally:
             # Reset everything that has a state, even if the previous call raised an exception
            self._module_hook_manager.gramian_accumulation_phase.value = False
@@ -296,40 +286,4 @@ def compute_gramian(self, output: Tensor) -> Tensor:
             for gramian_computer in self._gramian_computers.values():
                 gramian_computer.reset()
 
-        unordered_gramian = reshape_gramian(square_gramian, ordered_shape)
-
-        if self._batch_dim is not None:
-            gramian = movedim_gramian(unordered_gramian, [-1], [self._batch_dim])
-        else:
-            gramian = unordered_gramian
-
-        return gramian
-
-    def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> Tensor:
-        leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))
-
-        def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
-            return torch.autograd.grad(
-                outputs=output,
-                inputs=leaf_targets,
-                grad_outputs=_grad_output,
-                retain_graph=True,
-            )
-
-        if has_non_batch_dim:
-            # There is one non-batched dimension, it is the first one
-            non_batch_dim_len = output.shape[0]
-            identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype)
-            ones = torch.ones_like(output[0])
-            jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones)
-
-            _ = vmap(differentiation)(jac_output)
-        else:
-            grad_output = torch.ones_like(output)
-            _ = differentiation(grad_output)
-
-        # If the gramian were None, then leaf_targets would be empty, so autograd.grad would
-        # have failed. So gramian is necessarily a valid Tensor here.
-        gramian = cast(Tensor, self._gramian_accumulator.gramian)
-
         return gramian
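For illustration, a minimal standalone sketch of the new compute_gramian pattern, with a dense identity-like tensor standing in for DiagonalSparseTensor (assuming that class represents a tensor with ones on its generalized diagonal), and assuming, as the engine above does, that the PyTorch version in use supports calling torch.autograd.grad under vmap for the ops involved:

import torch

x = torch.randn(3, 4, requires_grad=True)
output = (2 * x).sin()  # toy "model output" of shape [3, 4]

def differentiation(grad_output: torch.Tensor) -> torch.Tensor:
    return torch.autograd.grad(output, x, grad_outputs=grad_output, retain_graph=True)[0]

# Dense stand-in for DiagonalSparseTensor(torch.ones_like(output), [0, 1, 0, 1]):
# jac_output[i, j] is the one-hot tensor that selects output element (i, j).
jac_output = torch.eye(output.numel()).reshape(3, 4, 3, 4)

vmapped_diff = differentiation
for _ in range(output.ndim):  # one vmap level per output dimension
    vmapped_diff = torch.vmap(vmapped_diff)

jacobian = vmapped_diff(jac_output)  # shape [3, 4, 3, 4]: d(output) / d(x), output dims in front

Because jac_output one-hot-selects every output element and each vmap level prepends one output dimension to the result, no reshape or movedim is needed afterwards, which is what the removed code above used to handle.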

src/torchjd/autogram/_gramian_computer.py

Lines changed: 12 additions & 13 deletions
@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import Optional
 
+import torch
 from torch import Tensor
-from torch.utils._pytree import PyTree
 
 from torchjd.autogram._jacobian_computer import JacobianComputer
 
@@ -13,8 +13,6 @@ def __call__(
         self,
         rg_outputs: tuple[Tensor, ...],
         grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
     ) -> Optional[Tensor]:
         """Compute what we can for a module and optionally return the gramian if it's ready."""
 
@@ -30,8 +28,12 @@ def __init__(self, jacobian_computer):
         self.jacobian_computer = jacobian_computer
 
     @staticmethod
-    def _to_gramian(jacobian: Tensor) -> Tensor:
-        return jacobian @ jacobian.T
+    def _to_gramian(matrix: Tensor) -> Tensor:
+        """Contracts the last dimension of matrix to make it into a Gramian."""
+
+        indices = list(range(matrix.ndim))
+        transposed_matrix = matrix.movedim(indices, indices[::-1])
+        return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0]))
 
 
 class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
@@ -53,20 +55,17 @@ def track_forward_call(self) -> None:
         self.remaining_counter += 1
 
     def __call__(
-        self,
-        rg_outputs: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
+        self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]
     ) -> Optional[Tensor]:
         """Compute what we can for a module and optionally return the gramian if it's ready."""
 
-        jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)
+        batched_jacobian = self.jacobian_computer(rg_outputs, grad_outputs)
+        jacobian = torch.func.debug_unwrap(batched_jacobian, recurse=True)
 
         if self.summed_jacobian is None:
-            self.summed_jacobian = jacobian_matrix
+            self.summed_jacobian = jacobian
        else:
-            self.summed_jacobian += jacobian_matrix
+            self.summed_jacobian += jacobian
 
         self.remaining_counter -= 1
 
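As a quick note on why JacobianBasedGramianComputerWithCrossTerms sums the per-call Jacobians before contracting them (rather than summing per-call Gramians): the Gramian of the summed Jacobian contains the cross terms between repeated calls to the same module. A tiny standalone check with plain matrices (not library code):

import torch

# Jacobians of the same module's parameters from two forward calls (arbitrary sizes).
J1, J2 = torch.randn(3, 5), torch.randn(3, 5)

gramian_of_sum = (J1 + J2) @ (J1 + J2).T   # what summing Jacobians first gives
sum_of_gramians = J1 @ J1.T + J2 @ J2.T    # what summing per-call Gramians would give
cross_terms = J1 @ J2.T + J2 @ J1.T        # the difference between the two
assert torch.allclose(gramian_of_sum, sum_of_gramians + cross_terms)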

src/torchjd/autogram/_jacobian_computer.py

Lines changed: 7 additions & 138 deletions
@@ -1,11 +1,9 @@
-from abc import ABC, abstractmethod
-from collections.abc import Callable
-from typing import cast
+from abc import ABC
 
 import torch
 from torch import Tensor, nn
 from torch.nn import Parameter
-from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
+from torch.utils._pytree import tree_flatten
 
 # Note about import from protected _pytree module:
 # PyTorch maintainers plan to make pytree public (see
@@ -25,112 +23,26 @@ class JacobianComputer(ABC):
 
     def __init__(self, module: nn.Module):
         self.module = module
-
         self.rg_params = dict[str, Parameter]()
-        self.frozen_params = dict[str, Parameter]()
 
         for name, param in module.named_parameters(recurse=True):
             if param.requires_grad:
                 self.rg_params[name] = param
-            else:
-                self.frozen_params[name] = param
-
-    def __call__(
-        self,
-        rg_outputs: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> Tensor:
-        # This makes __call__ vmappable.
-        return ComputeModuleJacobians.apply(
-            self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs
-        )
 
-    @abstractmethod
-    def _compute_jacobian(
-        self,
-        rg_outputs: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> Tensor:
+    def __call__(self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]) -> Tensor:
         """
-        Computes and returns the Jacobian. The output must be a matrix (2D Tensor).
+        Computes and returns the Jacobian. The output must be a generalized Jacobian with param
+        dimensions grouped.
         """
 
 
-class FunctionalJacobianComputer(JacobianComputer):
-    """
-    JacobianComputer using the functional differentiation API. This requires to use vmap, so it's
-    not compatible with every module, and it requires to have an extra forward pass to create the
-    vjp function.
-    """
-
-    def _compute_jacobian(
-        self,
-        _: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> Tensor:
-        grad_outputs_in_dims = (0,) * len(grad_outputs)
-        args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
-        kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
-        in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims)
-        vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)
-
-        return vmapped_vjp(grad_outputs, args, kwargs)
-
-    def _call_on_one_instance(
-        self,
-        grad_outputs_j: tuple[Tensor, ...],
-        args_j: tuple[PyTree, ...],
-        kwargs_j: dict[str, PyTree],
-    ) -> Tensor:
-        # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
-        # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
-        # nn.Flatten) do not work equivalently if they're provided with a batch or with
-        # an element of a batch. We thus always provide them with batches, just of a
-        # different size.
-        args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
-        kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j)
-        grad_outputs_j_ = tuple(x.unsqueeze(0) for x in grad_outputs_j)
-
-        def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]:
-            all_state = [
-                cast(dict[str, Tensor], rg_params),
-                dict(self.module.named_buffers()),
-                cast(dict[str, Tensor], self.frozen_params),
-            ]
-            output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
-            flat_outputs = tree_flatten(output)[0]
-            rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad)
-            return rg_outputs
-
-        vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]
-
-        # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the
-        # functional has a single primal which is dict(module.named_parameters()). We therefore take
-        # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
-        gradients = vjp_func(grad_outputs_j_)[0]
-        gradient = torch.cat([t.reshape(-1) for t in gradients.values()])
-        return gradient
-
-
 class AutogradJacobianComputer(JacobianComputer):
     """
     JacobianComputer using the autograd engine. The main advantage of using this method is that it
     doesn't require making an extra forward pass.
     """
 
-    def _compute_jacobian(
-        self,
-        rg_outputs: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        _: tuple[PyTree, ...],
-        __: dict[str, PyTree],
-    ) -> Tensor:
+    def __call__(self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]) -> Tensor:
        flat_rg_params, ___ = tree_flatten(self.rg_params)
         grads = torch.autograd.grad(
             rg_outputs,
@@ -141,47 +53,4 @@ def _compute_jacobian(
             materialize_grads=True,
         )
         flattened_grads = torch.cat([g.reshape(-1) for g in grads])
-        jacobian = flattened_grads.unsqueeze(0)
-        return jacobian
-
-
-class ComputeModuleJacobians(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        compute_jacobian_fn: Callable[
-            [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Tensor
-        ],
-        rg_outputs: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> Tensor:
-        # There is no non-batched dimension
-        jacobian = compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs)
-        return jacobian
-
-    @staticmethod
-    def vmap(
-        _,
-        in_dims: tuple[None, None, tuple[int, ...], None, None],
-        compute_jacobian_fn: Callable,
-        rg_outputs: tuple[Tensor, ...],
-        jac_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> tuple[Tensor, None]:
-        # There is a non-batched dimension
-        # We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension
-        generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])(
-            rg_outputs,
-            jac_outputs,
-            args,
-            kwargs,
-        )
-        shape = generalized_jacobian.shape
-        jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1])
-        return jacobian, None
-
-    @staticmethod
-    def setup_context(*_) -> None:
-        pass
+        return flattened_grads
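For context, a minimal standalone sketch (not library code) of what the core of AutogradJacobianComputer.__call__ now returns for a single grad_output: the flattened gradient of the module's requires-grad parameters, i.e. one slice of the generalized Jacobian; the surrounding output dimensions are added by the vmap levels applied in the engine. The module and shapes here are made up for illustration:

import torch
from torch import nn
from torch.utils._pytree import tree_flatten

module = nn.Linear(4, 3)                  # hypothetical module, 3 * 4 + 3 = 15 parameters in total
x = torch.randn(2, 4)
rg_output = module(x)                     # output requiring grad, shape [2, 3]
grad_output = torch.ones_like(rg_output)  # stands in for one slice of the engine's jac_output

rg_params = {name: p for name, p in module.named_parameters() if p.requires_grad}
flat_rg_params, _ = tree_flatten(rg_params)
grads = torch.autograd.grad(
    rg_output,
    flat_rg_params,
    grad_outputs=grad_output,
    retain_graph=True,
    materialize_grads=True,
)
flattened_grads = torch.cat([g.reshape(-1) for g in grads])  # shape [15]: one Jacobian slice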
