22
33from abc import ABC , abstractmethod
44from collections .abc import Sequence
5- from typing import Generic
5+ from typing import TypeAlias
66
77from torch import Tensor
88
9- from ._tensor_dict import _A , _B , _C , EmptyTensorDict , _least_common_ancestor
9+ TensorDict : TypeAlias = dict [Tensor , Tensor ]
10+ # Some interesting cases of TensorDict that are worth defining informally (for performance reasons):
11+ # Gradients: A TensorDict in which the shape of each value must be the same as the shape of its
12+ # corresponding key.
13+ # Jacobians: A TensorDict in which the values must all have the same first dimension and the rest of
14+ # the shape of each value must be the same as the shape of its corresponding key.
15+ # GradientVectors: A TensorDict containing flattened gradients: the values must be vectors with the
16+ # same number of elements as their corresponding key.
17+ # JacobianMatrices: A TensorDict containing matrixified (flattened into matrix shape) jacobians: the
18+ # values must be matrices with a unique first dimension and with a second dimension equal to the
19+ # number of elements of their corresponding key.
1020
1121
1222class RequirementError (ValueError ):
@@ -15,23 +25,23 @@ class RequirementError(ValueError):
1525 pass
1626
1727
18- class Transform (Generic [ _B , _C ], ABC ):
28+ class Transform (ABC ):
1929 """
2030 Abstract base class for all transforms. Transforms are elementary building blocks of a jacobian
2131 descent backward phase. A transform maps a TensorDict to another.
2232 """
2333
24- def compose (self , other : Transform [ _A , _B ] ) -> Transform [ _A , _C ] :
34+ def compose (self , other : Transform ) -> Transform :
2535 return Composition (self , other )
2636
27- def conjunct (self , other : Transform [ _B , _C ] ) -> Transform [ _B , _C ] :
37+ def conjunct (self , other : Transform ) -> Transform :
2838 return Conjunction ([self , other ])
2939
3040 def __str__ (self ) -> str :
3141 return type (self ).__name__
3242
3343 @abstractmethod
34- def __call__ (self , input : _B ) -> _C :
44+ def __call__ (self , input : TensorDict ) -> TensorDict :
3545 """Applies the transform to the input."""
3646
3747 @abstractmethod
@@ -51,22 +61,22 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
5161 __or__ = conjunct
5262
5363
54- class Composition (Transform [ _B , _C ] ):
64+ class Composition (Transform ):
5565 """
5666 Transform corresponding to the composition of two transforms inner and outer.
5767
5868 :param inner: The transform to apply first, to the input.
5969 :param outer: The transform to apply second, to the result of ``inner``.
6070 """
6171
62- def __init__ (self , outer : Transform [ _A , _C ], inner : Transform [ _B , _A ] ):
72+ def __init__ (self , outer : Transform , inner : Transform ):
6373 self .outer = outer
6474 self .inner = inner
6575
6676 def __str__ (self ) -> str :
6777 return str (self .outer ) + " ∘ " + str (self .inner )
6878
69- def __call__ (self , input : _B ) -> _C :
79+ def __call__ (self , input : TensorDict ) -> TensorDict :
7080 intermediate = self .inner (input )
7181 return self .outer (intermediate )
7282
@@ -76,15 +86,15 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
7686 return output_keys
7787
7888
79- class Conjunction (Transform [ _B , _C ] ):
89+ class Conjunction (Transform ):
8090 """
8191 Transform applying several transforms to the same input, and combining the results (by union)
8292 into a single TensorDict.
8393
8494 :param transforms: The transforms to apply. Their outputs should have disjoint sets of keys.
8595 """
8696
87- def __init__ (self , transforms : Sequence [Transform [ _B , _C ] ]):
97+ def __init__ (self , transforms : Sequence [Transform ]):
8898 self .transforms = transforms
8999
90100 def __str__ (self ) -> str :
@@ -97,14 +107,11 @@ def __str__(self) -> str:
97107 strings .append (s )
98108 return "(" + " | " .join (strings ) + ")"
99109
100- def __call__ (self , tensor_dict : _B ) -> _C :
101- tensor_dicts = [transform (tensor_dict ) for transform in self .transforms ]
102- output_type : type [_B ] = EmptyTensorDict
103- output : _B = EmptyTensorDict ()
104- for tensor_dict in tensor_dicts :
105- output_type = _least_common_ancestor (output_type , type (tensor_dict ))
106- output |= tensor_dict
107- return output_type (output )
110+ def __call__ (self , tensor_dict : TensorDict ) -> TensorDict :
111+ union : TensorDict = {}
112+ for transform in self .transforms :
113+ union |= transform (tensor_dict )
114+ return union
108115
109116 def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
110117 output_keys_list = [key for t in self .transforms for key in t .check_keys (input_keys )]
0 commit comments