Skip to content

Commit 3a27958

Browse files
authored
refactor(autojac): Fix genericity of Transforms (#369)
* Make _B contravariant and _C covariant * Fix generic parametrization in Composition and Conjunction * Make Stack contravariant in its input type
1 parent 28405e9 commit 3a27958

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

src/torchjd/_autojac/_transform/_base.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,22 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
5151
__or__ = conjunct
5252

5353

54-
class Composition(Transform[_A, _C]):
54+
class Composition(Transform[_B, _C]):
5555
"""
5656
Transform corresponding to the composition of two transforms inner and outer.
5757
5858
:param inner: The transform to apply first, to the input.
5959
:param outer: The transform to apply second, to the result of ``inner``.
6060
"""
6161

62-
def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]):
62+
def __init__(self, outer: Transform[_A, _C], inner: Transform[_B, _A]):
6363
self.outer = outer
6464
self.inner = inner
6565

6666
def __str__(self) -> str:
6767
return str(self.outer) + " ∘ " + str(self.inner)
6868

69-
def __call__(self, input: _A) -> _C:
69+
def __call__(self, input: _B) -> _C:
7070
intermediate = self.inner(input)
7171
return self.outer(intermediate)
7272

@@ -76,15 +76,15 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
7676
return output_keys
7777

7878

79-
class Conjunction(Transform[_A, _B]):
79+
class Conjunction(Transform[_B, _C]):
8080
"""
8181
Transform applying several transforms to the same input, and combining the results (by union)
8282
into a single TensorDict.
8383
8484
:param transforms: The transforms to apply. Their outputs should have disjoint sets of keys.
8585
"""
8686

87-
def __init__(self, transforms: Sequence[Transform[_A, _B]]):
87+
def __init__(self, transforms: Sequence[Transform[_B, _C]]):
8888
self.transforms = transforms
8989

9090
def __str__(self) -> str:
@@ -97,10 +97,10 @@ def __str__(self) -> str:
9797
strings.append(s)
9898
return "(" + " | ".join(strings) + ")"
9999

100-
def __call__(self, tensor_dict: _A) -> _B:
100+
def __call__(self, tensor_dict: _B) -> _C:
101101
tensor_dicts = [transform(tensor_dict) for transform in self.transforms]
102-
output_type: type[_A] = EmptyTensorDict
103-
output: _A = EmptyTensorDict()
102+
output_type: type[_B] = EmptyTensorDict
103+
output: _B = EmptyTensorDict()
104104
for tensor_dict in tensor_dicts:
105105
output_type = _least_common_ancestor(output_type, type(tensor_dict))
106106
output |= tensor_dict

src/torchjd/_autojac/_transform/_stack.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
from ._base import Transform
77
from ._materialize import materialize
8-
from ._tensor_dict import _A, Gradients, Jacobians
8+
from ._tensor_dict import _B, Gradients, Jacobians
99

1010

11-
class Stack(Transform[_A, Jacobians]):
11+
class Stack(Transform[_B, Jacobians]):
1212
"""
1313
Transform applying several transforms to the same input, and combining the results (by stacking)
1414
into a single TensorDict.
@@ -20,10 +20,10 @@ class Stack(Transform[_A, Jacobians]):
2020
at the positions corresponding to those dicts.
2121
"""
2222

23-
def __init__(self, transforms: Sequence[Transform[_A, Gradients]]):
23+
def __init__(self, transforms: Sequence[Transform[_B, Gradients]]):
2424
self.transforms = transforms
2525

26-
def __call__(self, input: _A) -> Jacobians:
26+
def __call__(self, input: _B) -> Jacobians:
2727
results = [transform(input) for transform in self.transforms]
2828
result = _stack(results)
2929
return result

src/torchjd/_autojac/_transform/_tensor_dict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,5 +182,5 @@ def _check_corresponding_numel(key: Tensor, value: Tensor, dim: int) -> None:
182182

183183

184184
_A = TypeVar("_A", bound=TensorDict)
185-
_B = TypeVar("_B", bound=TensorDict)
186-
_C = TypeVar("_C", bound=TensorDict)
185+
_B = TypeVar("_B", bound=TensorDict, contravariant=True)
186+
_C = TypeVar("_C", bound=TensorDict, covariant=True)

0 commit comments

Comments
 (0)