Commit 5acab1d

refactor(autojac): Remove reshaping to/from matrix in Jac (#420)
* Stop reshaping Jacobians in Jac
* Factorize _get_vjp between Grad and Jac
* Add changelog entry
1 parent 6431bbf commit 5acab1d

4 files changed: +33 -55 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ changes that do not affect the user.
 
 ### Changed
 
+- Removed an unnecessary internal reshape when computing Jacobians. This should have no effect but a
+  slight performance improvement in `autojac`.
 - Revamped documentation.
 - Made `backward` and `mtl_backward` importable from `torchjd.autojac` (like it was prior to 0.7.0).
 - Deprecated importing `backward` and `mtl_backward` from `torchjd` directly.
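
The entry above summarizes the change: Jacobians are now kept as one tensor per input, of shape [m, *input.shape], instead of being flattened into a single [m, n] matrix and reshaped back afterwards. As a rough standalone illustration of that representation (plain PyTorch with made-up names, not torchjd code), the per-input Jacobians can be built directly from vector-Jacobian products:

import torch

# Two inputs with different shapes, and m = 2 outputs depending on both.
a = torch.randn(3, requires_grad=True)
b = torch.randn(2, 2, requires_grad=True)
y = torch.stack([a.sum() + b.sum(), (a ** 2).sum() + (b ** 2).sum()])

def vjp(v: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # One gradient per input, each keeping the input's own shape.
    return torch.autograd.grad(y, (a, b), grad_outputs=v, retain_graph=True)

rows = [vjp(v) for v in torch.eye(2)]        # one vjp per output
jac_a = torch.stack([r[0] for r in rows])    # shape [2, 3]
jac_b = torch.stack([r[1] for r in rows])    # shape [2, 2, 2]; no flatten/reshape round-trip
print(jac_a.shape, jac_b.shape)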

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 14 additions & 0 deletions
@@ -1,9 +1,11 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 
+import torch
 from torch import Tensor
 
 from ._base import RequirementError, TensorDict, Transform
+from ._materialize import materialize
 from ._ordered_set import OrderedSet
 
 
@@ -61,3 +63,15 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
                 f"outputs {outputs}."
             )
         return set(self.inputs)
+
+    def _get_vjp(self, grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor, ...]:
+        optional_grads = torch.autograd.grad(
+            self.outputs,
+            self.inputs,
+            grad_outputs=grad_outputs,
+            retain_graph=retain_graph,
+            create_graph=self.create_graph,
+            allow_unused=True,
+        )
+        grads = materialize(optional_grads, inputs=self.inputs)
+        return grads
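
The factored-out `_get_vjp` wraps `torch.autograd.grad` with `allow_unused=True` and then materializes the possibly-`None` gradients through `materialize`. A minimal standalone sketch of that pattern (the zero-filling helper below is a stand-in for torchjd's internal `materialize`, not its actual implementation):

from collections.abc import Sequence

import torch
from torch import Tensor

def materialize_like(optional_grads: Sequence[Tensor | None], inputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
    # Replace None gradients (inputs unused by the outputs) with zero tensors of matching shape.
    return tuple(
        g if g is not None else torch.zeros_like(inp) for g, inp in zip(optional_grads, inputs)
    )

x = torch.randn(3, requires_grad=True)
unused = torch.randn(2, requires_grad=True)  # does not influence y
y = (x ** 2).sum()

optional_grads = torch.autograd.grad(
    (y,), (x, unused), grad_outputs=(torch.ones(()),), allow_unused=True
)
grads = materialize_like(optional_grads, (x, unused))
print(grads[0])  # 2 * x
print(grads[1])  # zeros with the shape of `unused`, instead of None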

src/torchjd/autojac/_transform/_grad.py

Lines changed: 1 addition & 10 deletions
@@ -4,7 +4,6 @@
 from torch import Tensor
 
 from ._differentiate import Differentiate
-from ._materialize import materialize
 from ._ordered_set import OrderedSet
 
 
@@ -54,13 +53,5 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
         if len(self.outputs) == 0:
             return tuple([torch.zeros_like(input) for input in self.inputs])
 
-        optional_grads = torch.autograd.grad(
-            self.outputs,
-            self.inputs,
-            grad_outputs=grad_outputs,
-            retain_graph=self.retain_graph,
-            create_graph=self.create_graph,
-            allow_unused=True,
-        )
-        grads = materialize(optional_grads, self.inputs)
+        grads = self._get_vjp(grad_outputs, self.retain_graph)
         return grads
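
With the helper factored into `Differentiate`, `Grad._differentiate` reduces to an early exit (zero gradients when there are no outputs) followed by a single vjp call. A rough standalone sketch of that structure, under made-up names:

from collections.abc import Sequence

import torch
from torch import Tensor

def grad_or_zeros(
    outputs: Sequence[Tensor], inputs: Sequence[Tensor], grad_outputs: Sequence[Tensor]
) -> tuple[Tensor, ...]:
    if len(outputs) == 0:
        # Nothing to differentiate: return zero gradients matching each input.
        return tuple(torch.zeros_like(inp) for inp in inputs)
    return torch.autograd.grad(outputs, inputs, grad_outputs=grad_outputs)

x = torch.randn(4, requires_grad=True)
y = (x ** 2).sum()
print(grad_or_zeros([y], [x], [torch.ones(())])[0])  # 2 * x
print(grad_or_zeros([], [x], [])[0])                 # zeros like x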

src/torchjd/autojac/_transform/_jac.py

Lines changed: 16 additions & 45 deletions
@@ -1,13 +1,11 @@
 import math
 from collections.abc import Callable, Sequence
 from functools import partial
-from itertools import accumulate
 
 import torch
-from torch import Size, Tensor
+from torch import Tensor
 
 from ._differentiate import Differentiate
-from ._materialize import materialize
 from ._ordered_set import OrderedSet
 
 
@@ -69,53 +67,37 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
             ]
         )
 
-        def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
-            optional_grads = torch.autograd.grad(
-                self.outputs,
-                self.inputs,
-                grad_outputs=grad_outputs,
-                retain_graph=retain_graph,
-                create_graph=self.create_graph,
-                allow_unused=True,
-            )
-            grads = materialize(optional_grads, inputs=self.inputs)
-            return torch.concatenate([grad.reshape([-1]) for grad in grads])
-
         # If the jac_outputs are correct, this value should be the same for all jac_outputs.
         m = jac_outputs[0].shape[0]
         max_chunk_size = self.chunk_size if self.chunk_size is not None else m
         n_chunks = math.ceil(m / max_chunk_size)
 
-        # List of tensors of shape [k_i, n] where the k_i's sum to m
-        jac_matrix_chunks = []
+        # One tuple per chunk (i), with one value per input (j), of shape [k_i] + shape[j],
+        # where k_i is the number of rows in the chunk (the k_i's sum to m)
+        jacs_chunks: list[tuple[Tensor, ...]] = []
 
         # First differentiations: always retain graph
-        get_vjp_retain = partial(_get_vjp, retain_graph=True)
+        get_vjp_retain = partial(self._get_vjp, retain_graph=True)
         for i in range(n_chunks - 1):
             start = i * max_chunk_size
             end = (i + 1) * max_chunk_size
             jac_outputs_chunk = [jac_output[start:end] for jac_output in jac_outputs]
-            jac_matrix_chunks.append(_get_jac_matrix_chunk(jac_outputs_chunk, get_vjp_retain))
+            jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_retain))
 
         # Last differentiation: retain the graph only if self.retain_graph==True
-        get_vjp_last = partial(_get_vjp, retain_graph=self.retain_graph)
+        get_vjp_last = partial(self._get_vjp, retain_graph=self.retain_graph)
         start = (n_chunks - 1) * max_chunk_size
         jac_outputs_chunk = [jac_output[start:] for jac_output in jac_outputs]
-        jac_matrix_chunks.append(_get_jac_matrix_chunk(jac_outputs_chunk, get_vjp_last))
-
-        jac_matrix = torch.vstack(jac_matrix_chunks)
-        lengths = [input.numel() for input in self.inputs]
-        jac_matrices = _extract_sub_matrices(jac_matrix, lengths)
-
-        shapes = [input.shape for input in self.inputs]
-        jacs = _reshape_matrices(jac_matrices, shapes)
+        jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_last))
 
-        return tuple(jacs)
+        n_inputs = len(self.inputs)
+        jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs))
+        return jacs
 
 
-def _get_jac_matrix_chunk(
-    jac_outputs_chunk: list[Tensor], get_vjp: Callable[[Sequence[Tensor]], Tensor]
-) -> Tensor:
+def _get_jacs_chunk(
+    jac_outputs_chunk: list[Tensor], get_vjp: Callable[[Sequence[Tensor]], tuple[Tensor, ...]]
+) -> tuple[Tensor, ...]:
     """
     Computes the jacobian matrix chunk corresponding to the provided get_vjp function, either by
     calling get_vjp directly or by wrapping it into a call to ``torch.vmap``, depending on the shape
@@ -126,18 +108,7 @@ def _get_jac_matrix_chunk(
     chunk_size = jac_outputs_chunk[0].shape[0]
     if chunk_size == 1:
         grad_outputs = [tensor.squeeze(0) for tensor in jac_outputs_chunk]
-        gradient_vector = get_vjp(grad_outputs)
-        return gradient_vector.unsqueeze(0)
+        gradients = get_vjp(grad_outputs)
+        return tuple(gradient.unsqueeze(0) for gradient in gradients)
     else:
         return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk)
-
-
-def _extract_sub_matrices(matrix: Tensor, lengths: Sequence[int]) -> list[Tensor]:
-    cumulative_lengths = [*accumulate(lengths)]
-    start_indices = [0] + cumulative_lengths[:-1]
-    end_indices = cumulative_lengths
-    return [matrix[:, start:end] for start, end in zip(start_indices, end_indices)]
-
-
-def _reshape_matrices(matrices: Sequence[Tensor], shapes: Sequence[Size]) -> Sequence[Tensor]:
-    return [matrix.view((matrix.shape[0],) + shape) for matrix, shape in zip(matrices, shapes)]
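
After this refactor, each chunk of rows is differentiated through `_get_vjp` (batched with `torch.vmap` when the chunk has more than one row), producing one tensor per input whose leading dimension is the chunk size; the chunks are then concatenated along that leading dimension, one input at a time, with no intermediate matrix. A rough standalone illustration of the chunked, vmapped vjp pattern (plain PyTorch with made-up names, not the torchjd internals):

import torch
from torch import Tensor

x = torch.randn(3, requires_grad=True)
w = torch.randn(2, 2, requires_grad=True)
y = torch.stack(
    [x.sum() * w.sum(), (x ** 2).sum(), (w ** 2).sum(), x.sum() - w.sum()]
)  # m = 4 outputs

def get_vjp(grad_output: Tensor) -> tuple[Tensor, ...]:
    # One gradient per input, each keeping the input's own shape.
    return torch.autograd.grad(y, (x, w), grad_outputs=grad_output, retain_graph=True)

jac_outputs = torch.eye(4)  # each row selects one output
m, max_chunk_size = jac_outputs.shape[0], 2
chunks: list[tuple[Tensor, ...]] = []
for start in range(0, m, max_chunk_size):
    chunk = jac_outputs[start : start + max_chunk_size]  # shape [k_i, m]
    chunks.append(torch.vmap(get_vjp, chunk_size=chunk.shape[0])(chunk))

# Concatenate along the row dimension, one tensor per input (no vstack/split/reshape).
jacs = tuple(torch.cat([c[i] for c in chunks]) for i in range(2))
print(jacs[0].shape, jacs[1].shape)  # torch.Size([4, 3]) torch.Size([4, 2, 2])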
