
Commit 849d7e7

Remove TensorDict classes (#382)
* Remove TensorDicts and update Transforms and their usages accordingly
* Add changelog entry
1 parent 75ef97a commit 849d7e7

27 files changed: +185 −542 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,8 @@ changes that do not affect the user.
   `inf` or `-inf` values. This check was costly in memory and in time for large matrices so this
   should improve performance. However, if the optimization diverges for some reason (for instance
   due to a too large learning rate), the resulting exceptions may come from other sources.
+- Removed some runtime checks on the shapes of the internal tensors used by the `autojac` engine.
+  This should lead to a small performance improvement.
 
 ### Fixed
 
src/torchjd/_autojac/_backward.py

Lines changed: 3 additions & 12 deletions
@@ -4,16 +4,7 @@
 
 from torchjd.aggregation import Aggregator
 
-from ._transform import (
-    Accumulate,
-    Aggregate,
-    Diagonalize,
-    EmptyTensorDict,
-    Init,
-    Jac,
-    OrderedSet,
-    Transform,
-)
+from ._transform import Accumulate, Aggregate, Diagonalize, Init, Jac, OrderedSet, Transform
 from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
 
 
@@ -95,7 +86,7 @@ def backward(
         parallel_chunk_size=parallel_chunk_size,
     )
 
-    backward_transform(EmptyTensorDict())
+    backward_transform({})
 
 
 def _create_transform(
@@ -104,7 +95,7 @@ def _create_transform(
     inputs: OrderedSet[Tensor],
     retain_graph: bool,
     parallel_chunk_size: int | None,
-) -> Transform[EmptyTensorDict, EmptyTensorDict]:
+) -> Transform:
     """Creates the Jacobian descent backward transform."""
 
     # Transform that creates gradient outputs containing only ones.
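The only behavioral difference in this file is the seed: `backward_transform` is now called with a plain empty `dict` rather than an `EmptyTensorDict` instance. A minimal sketch of that calling convention, using a hypothetical `Ones` transform (not part of this diff) in place of torchjd's `Init`:

```python
import torch
from torch import Tensor

TensorDict = dict[Tensor, Tensor]  # the alias introduced in _base.py below


class Ones:
    """Hypothetical seeding transform: maps {} to an all-ones dict."""

    def __init__(self, keys: list[Tensor]):
        self.keys = keys

    def __call__(self, input: TensorDict) -> TensorDict:
        # Ignores the (empty) input and emits one all-ones value per key.
        return {key: torch.ones_like(key) for key in self.keys}


t = torch.zeros(2, 3)
out = Ones([t])({})  # seeded with a plain dict, as backward_transform({}) now is
assert out[t].shape == t.shape  # a "Gradients"-shaped TensorDict
```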

src/torchjd/_autojac/_mtl_backward.py

Lines changed: 6 additions & 18 deletions
@@ -4,19 +4,7 @@
 
 from torchjd.aggregation import Aggregator
 
-from ._transform import (
-    Accumulate,
-    Aggregate,
-    EmptyTensorDict,
-    Grad,
-    Gradients,
-    Init,
-    Jac,
-    OrderedSet,
-    Select,
-    Stack,
-    Transform,
-)
+from ._transform import Accumulate, Aggregate, Grad, Init, Jac, OrderedSet, Select, Stack, Transform
 from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
 
 
@@ -114,7 +102,7 @@ def mtl_backward(
         parallel_chunk_size=parallel_chunk_size,
     )
 
-    backward_transform(EmptyTensorDict())
+    backward_transform({})
 
 
 def _create_transform(
@@ -125,7 +113,7 @@ def _create_transform(
     shared_params: OrderedSet[Tensor],
     retain_graph: bool,
     parallel_chunk_size: int | None,
-) -> Transform[EmptyTensorDict, EmptyTensorDict]:
+) -> Transform:
     """
     Creates the backward transform for a multi-task learning problem. It is a hybrid between
     Jacobian descent (for shared parameters) and multiple gradient descent branches (for
@@ -166,7 +154,7 @@ def _create_task_transform(
     task_params: OrderedSet[Tensor],
     loss: OrderedSet[Tensor],  # contains a single scalar loss
     retain_graph: bool,
-) -> Transform[EmptyTensorDict, Gradients]:
+) -> Transform:
     # Tensors with respect to which we compute the gradients.
     to_differentiate = task_params + features
 
@@ -179,10 +167,10 @@ def _create_task_transform(
 
     # Transform that accumulates the gradients w.r.t. the task-specific parameters into their
     # .grad fields.
-    accumulate = Accumulate() << Select[Gradients](task_params)
+    accumulate = Accumulate() << Select(task_params)
 
     # Transform that backpropagates the gradients of the losses w.r.t. the features.
-    backpropagate = Select[Gradients](features)
+    backpropagate = Select(features)
 
     # Transform that accumulates the gradient of the losses w.r.t. the task-specific parameters into
     # their .grad fields and backpropagates the gradient of the losses w.r.t. to the features.
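`Select` also loses its subscript here (`Select[Gradients](...)` becomes `Select(...)`) because transforms are no longer generic. A self-contained sketch of what `Accumulate() << Select(task_params)` composes to; the `Transform`, `Select`, and `Accumulate` classes below are simplified stand-ins for torchjd's, not their actual implementations:

```python
import torch
from torch import Tensor

TensorDict = dict[Tensor, Tensor]


class Transform:
    """Minimal stand-in base: `<<` composes right to left, like torchjd's."""

    def __call__(self, d: TensorDict) -> TensorDict:
        raise NotImplementedError

    def __lshift__(self, inner: "Transform") -> "Transform":
        outer = self

        class Composed(Transform):
            def __call__(self, d: TensorDict) -> TensorDict:
                return outer(inner(d))  # apply inner first, then outer

        return Composed()


class Select(Transform):
    """Keeps only the entries whose keys belong to `keys`."""

    def __init__(self, keys: list[Tensor]):
        self.keys = set(keys)  # tensors hash by identity, so this is safe

    def __call__(self, d: TensorDict) -> TensorDict:
        return {k: v for k, v in d.items() if k in self.keys}


class Accumulate(Transform):
    """Adds each value into its key's .grad field and returns {}."""

    def __call__(self, d: TensorDict) -> TensorDict:
        for key, grad in d.items():
            key.grad = grad if key.grad is None else key.grad + grad
        return {}


p = torch.zeros(2, requires_grad=True)
q = torch.zeros(3, requires_grad=True)

# Mirrors `accumulate = Accumulate() << Select(task_params)` from the diff.
accumulate = Accumulate() << Select([p])
accumulate({p: torch.ones(2), q: torch.ones(3)})  # q's entry is filtered out
print(p.grad, q.grad)  # tensor([1., 1.]) None
```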

src/torchjd/_autojac/_transform/__init__.py

Lines changed: 0 additions & 8 deletions
@@ -8,11 +8,3 @@
 from ._ordered_set import OrderedSet
 from ._select import Select
 from ._stack import Stack
-from ._tensor_dict import (
-    EmptyTensorDict,
-    Gradients,
-    GradientVectors,
-    JacobianMatrices,
-    Jacobians,
-    TensorDict,
-)

src/torchjd/_autojac/_transform/_accumulate.py

Lines changed: 8 additions & 6 deletions
@@ -1,13 +1,15 @@
 from torch import Tensor
 
-from ._base import Transform
-from ._tensor_dict import EmptyTensorDict, Gradients
+from ._base import TensorDict, Transform
 
 
-class Accumulate(Transform[Gradients, EmptyTensorDict]):
-    """Transform that accumulates gradients with respect to keys into their ``grad`` field."""
+class Accumulate(Transform):
+    """
+    Transform from Gradients to {} that accumulates gradients with respect to keys into their
+    ``grad`` field.
+    """
 
-    def __call__(self, gradients: Gradients) -> EmptyTensorDict:
+    def __call__(self, gradients: TensorDict) -> TensorDict:
         for key in gradients.keys():
             _check_expects_grad(key)
             if hasattr(key, "grad") and key.grad is not None:
@@ -19,7 +21,7 @@ def __call__(self, gradients: Gradients) -> EmptyTensorDict:
             # (in case it was obtained via create_graph=True and a differentiable aggregator).
             key.grad = gradients[key].clone()
 
-        return EmptyTensorDict()
+        return {}
 
     def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
         return set()
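The accumulation logic itself is untouched; only the annotations and the returned empty dict change. Its semantics, sketched in plain torch. Note that the hunk above elides the branch taken when `.grad` already exists, so the `key.grad + ...` line is an assumption about that elided code:

```python
import torch

p = torch.zeros(3, requires_grad=True)
p.grad = torch.tensor([1.0, 0.0, 0.0])  # pre-existing gradient
q = torch.zeros(2, requires_grad=True)  # no gradient yet

gradients = {p: torch.tensor([0.0, 2.0, 0.0]), q: torch.ones(2)}

for key in gradients.keys():
    if key.grad is not None:
        key.grad = key.grad + gradients[key]  # assumed elided branch: accumulate
    else:
        # Clone so the stored gradient does not alias the dict value
        # (relevant when it was obtained via create_graph=True).
        key.grad = gradients[key].clone()

print(p.grad)  # tensor([1., 2., 0.])
print(q.grad)  # tensor([1., 1.])
```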

src/torchjd/_autojac/_transform/_aggregate.py

Lines changed: 15 additions & 16 deletions
@@ -7,15 +7,14 @@
 
 from torchjd.aggregation import Aggregator
 
-from ._base import RequirementError, Transform
+from ._base import RequirementError, TensorDict, Transform
 from ._ordered_set import OrderedSet
-from ._tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians
 
 _KeyType = TypeVar("_KeyType", bound=Hashable)
 _ValueType = TypeVar("_ValueType")
 
 
-class Aggregate(Transform[Jacobians, Gradients]):
+class Aggregate(Transform):
     """
     Transform aggregating Jacobians into Gradients.
 
@@ -35,14 +34,14 @@ def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
         self._aggregator_str = str(aggregator)
         self.transform = reshape << aggregate_matrices << matrixify
 
-    def __call__(self, input: Jacobians) -> Gradients:
+    def __call__(self, input: TensorDict) -> TensorDict:
         return self.transform(input)
 
     def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
         return self.transform.check_keys(input_keys)
 
 
-class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
+class _AggregateMatrices(Transform):
     """
     Transform aggregating JacobiansMatrices into GradientsVectors.
 
@@ -57,7 +56,7 @@ def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
         self.key_order = key_order
         self.aggregator = aggregator
 
-    def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
+    def __call__(self, jacobian_matrices: TensorDict) -> TensorDict:
         """
         Concatenates the provided ``jacobian_matrices`` into a single matrix and aggregates it using
         the ``aggregator``. Returns the dictionary mapping each key from ``jacobian_matrices`` to
@@ -92,15 +91,15 @@ def _select_ordered_subdict(
     @staticmethod
     def _aggregate_group(
         jacobian_matrices: OrderedDict[Tensor, Tensor], aggregator: Aggregator
-    ) -> GradientVectors:
+    ) -> TensorDict:
         """
         Unites the jacobian matrices and aggregates them using an
         :class:`~torchjd.aggregation._aggregator_bases.Aggregator`. Returns the obtained gradient
         vectors.
         """
 
         if len(jacobian_matrices) == 0:
-            return EmptyTensorDict()
+            return {}
 
         united_jacobian_matrix = _AggregateMatrices._unite(jacobian_matrices)
         united_gradient_vector = aggregator(united_jacobian_matrix)
@@ -114,39 +113,39 @@ def _unite(jacobian_matrices: OrderedDict[Tensor, Tensor]) -> Tensor:
     @staticmethod
     def _disunite(
         united_gradient_vector: Tensor, jacobian_matrices: OrderedDict[Tensor, Tensor]
-    ) -> GradientVectors:
+    ) -> TensorDict:
         gradient_vectors = {}
         start = 0
         for key, jacobian_matrix in jacobian_matrices.items():
             end = start + jacobian_matrix.shape[1]
             current_gradient_vector = united_gradient_vector[start:end]
             gradient_vectors[key] = current_gradient_vector
             start = end
-        return GradientVectors(gradient_vectors)
+        return gradient_vectors
 
 
-class _Matrixify(Transform[Jacobians, JacobianMatrices]):
+class _Matrixify(Transform):
     """Transform reshaping Jacobians into JacobianMatrices."""
 
-    def __call__(self, jacobians: Jacobians) -> JacobianMatrices:
+    def __call__(self, jacobians: TensorDict) -> TensorDict:
         jacobian_matrices = {
             key: jacobian.view(jacobian.shape[0], -1) for key, jacobian in jacobians.items()
         }
-        return JacobianMatrices(jacobian_matrices)
+        return jacobian_matrices
 
     def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
         return input_keys
 
 
-class _Reshape(Transform[GradientVectors, Gradients]):
+class _Reshape(Transform):
     """Transform reshaping GradientVectors into Gradients."""
 
-    def __call__(self, gradient_vectors: GradientVectors) -> Gradients:
+    def __call__(self, gradient_vectors: TensorDict) -> TensorDict:
         gradients = {
             key: gradient_vector.view(key.shape)
             for key, gradient_vector in gradient_vectors.items()
         }
-        return Gradients(gradients)
+        return gradients
 
     def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
         return input_keys
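The pipeline built in `Aggregate.__init__` (`reshape << aggregate_matrices << matrixify`) is easiest to follow on concrete shapes. A sketch of the same steps in plain torch, with a row-mean standing in for a real `Aggregator`:

```python
import torch
from torch import Tensor

TensorDict = dict[Tensor, Tensor]

a = torch.zeros(2, 3)  # a parameter-like key
b = torch.zeros(4)     # another key

# Jacobians: 5 losses, one Jacobian per key (shared first dimension).
jacobians: TensorDict = {a: torch.ones(5, 2, 3), b: torch.ones(5, 4)}

# _Matrixify: each Jacobian becomes a (num_losses, num_elements) matrix.
matrices = {k: j.view(j.shape[0], -1) for k, j in jacobians.items()}

# _unite + aggregation: concatenate into one matrix, reduce to one vector.
united = torch.cat(list(matrices.values()), dim=1)  # shape (5, 10)
vector = united.mean(dim=0)                         # stand-in for an Aggregator

# _disunite: split the vector back per key, in insertion order.
gradient_vectors: TensorDict = {}
start = 0
for k, m in matrices.items():
    end = start + m.shape[1]
    gradient_vectors[k] = vector[start:end]
    start = end

# _Reshape: each flattened gradient back to its key's shape.
gradients = {k: v.view(k.shape) for k, v in gradient_vectors.items()}
assert gradients[a].shape == a.shape and gradients[b].shape == b.shape
```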

src/torchjd/_autojac/_transform/_base.py

Lines changed: 26 additions & 19 deletions
@@ -2,11 +2,21 @@
 
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import Generic
+from typing import TypeAlias
 
 from torch import Tensor
 
-from ._tensor_dict import _A, _B, _C, EmptyTensorDict, _least_common_ancestor
+TensorDict: TypeAlias = dict[Tensor, Tensor]
+# Some interesting cases of TensorDict that are worth defining informally (for performance reasons):
+# Gradients: A TensorDict in which the shape of each value must be the same as the shape of its
+# corresponding key.
+# Jacobians: A TensorDict in which the values must all have the same first dimension and the rest of
+# the shape of each value must be the same as the shape of its corresponding key.
+# GradientVectors: A TensorDict containing flattened gradients: the values must be vectors with the
+# same number of elements as their corresponding key.
+# JacobianMatrices: A TensorDict containing matrixified (flattened into matrix shape) jacobians: the
+# values must be matrices with a unique first dimension and with a second dimension equal to the
+# number of elements of their corresponding key.
 
 
 class RequirementError(ValueError):
@@ -15,23 +25,23 @@ class RequirementError(ValueError):
     pass
 
 
-class Transform(Generic[_B, _C], ABC):
+class Transform(ABC):
     """
     Abstract base class for all transforms. Transforms are elementary building blocks of a jacobian
     descent backward phase. A transform maps a TensorDict to another.
     """
 
-    def compose(self, other: Transform[_A, _B]) -> Transform[_A, _C]:
+    def compose(self, other: Transform) -> Transform:
         return Composition(self, other)
 
-    def conjunct(self, other: Transform[_B, _C]) -> Transform[_B, _C]:
+    def conjunct(self, other: Transform) -> Transform:
         return Conjunction([self, other])
 
     def __str__(self) -> str:
         return type(self).__name__
 
     @abstractmethod
-    def __call__(self, input: _B) -> _C:
+    def __call__(self, input: TensorDict) -> TensorDict:
         """Applies the transform to the input."""
 
     @abstractmethod
@@ -51,22 +61,22 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
     __or__ = conjunct
 
 
-class Composition(Transform[_B, _C]):
+class Composition(Transform):
     """
     Transform corresponding to the composition of two transforms inner and outer.
 
     :param inner: The transform to apply first, to the input.
     :param outer: The transform to apply second, to the result of ``inner``.
     """
 
-    def __init__(self, outer: Transform[_A, _C], inner: Transform[_B, _A]):
+    def __init__(self, outer: Transform, inner: Transform):
         self.outer = outer
         self.inner = inner
 
     def __str__(self) -> str:
         return str(self.outer) + " ∘ " + str(self.inner)
 
-    def __call__(self, input: _B) -> _C:
+    def __call__(self, input: TensorDict) -> TensorDict:
         intermediate = self.inner(input)
         return self.outer(intermediate)
 
@@ -76,15 +86,15 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
         return output_keys
 
 
-class Conjunction(Transform[_B, _C]):
+class Conjunction(Transform):
     """
     Transform applying several transforms to the same input, and combining the results (by union)
     into a single TensorDict.
 
     :param transforms: The transforms to apply. Their outputs should have disjoint sets of keys.
     """
 
-    def __init__(self, transforms: Sequence[Transform[_B, _C]]):
+    def __init__(self, transforms: Sequence[Transform]):
         self.transforms = transforms
 
     def __str__(self) -> str:
@@ -97,14 +107,11 @@ def __str__(self) -> str:
             strings.append(s)
         return "(" + " | ".join(strings) + ")"
 
-    def __call__(self, tensor_dict: _B) -> _C:
-        tensor_dicts = [transform(tensor_dict) for transform in self.transforms]
-        output_type: type[_B] = EmptyTensorDict
-        output: _B = EmptyTensorDict()
-        for tensor_dict in tensor_dicts:
-            output_type = _least_common_ancestor(output_type, type(tensor_dict))
-            output |= tensor_dict
-        return output_type(output)
+    def __call__(self, tensor_dict: TensorDict) -> TensorDict:
+        union: TensorDict = {}
+        for transform in self.transforms:
+            union |= transform(tensor_dict)
+        return union
 
     def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
         output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]
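The new `Conjunction.__call__` replaces the `_least_common_ancestor` bookkeeping with a plain dict union. A minimal sketch of that merge, using two lambdas with disjoint output keys (as the docstring requires) in place of real transforms:

```python
import torch
from torch import Tensor

TensorDict = dict[Tensor, Tensor]

x, y = torch.zeros(2), torch.zeros(3)

# Two stand-in transforms whose outputs have disjoint key sets.
transforms = [
    lambda d: {x: torch.ones(2)},
    lambda d: {y: torch.ones(3)},
]

tensor_dict: TensorDict = {}  # the shared input
union: TensorDict = {}
for transform in transforms:
    union |= transform(tensor_dict)  # plain dict union, no subclass tracking

assert set(union.keys()) == {x, y}
```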
