@@ -51,22 +51,22 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
5151 __or__ = conjunct
5252
5353
54- class Composition (Transform [_A , _C ]):
54+ class Composition (Transform [_B , _C ]):
5555 """
5656 Transform corresponding to the composition of two transforms inner and outer.
5757
5858 :param inner: The transform to apply first, to the input.
5959 :param outer: The transform to apply second, to the result of ``inner``.
6060 """
6161
62- def __init__ (self , outer : Transform [_B , _C ], inner : Transform [_A , _B ]):
62+ def __init__ (self , outer : Transform [_A , _C ], inner : Transform [_B , _A ]):
6363 self .outer = outer
6464 self .inner = inner
6565
6666 def __str__ (self ) -> str :
6767 return str (self .outer ) + " ∘ " + str (self .inner )
6868
69- def __call__ (self , input : _A ) -> _C :
69+ def __call__ (self , input : _B ) -> _C :
7070 intermediate = self .inner (input )
7171 return self .outer (intermediate )
7272
@@ -76,15 +76,15 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
7676 return output_keys
7777
7878
79- class Conjunction (Transform [_A , _B ]):
79+ class Conjunction (Transform [_B , _C ]):
8080 """
8181 Transform applying several transforms to the same input, and combining the results (by union)
8282 into a single TensorDict.
8383
8484 :param transforms: The transforms to apply. Their outputs should have disjoint sets of keys.
8585 """
8686
87- def __init__ (self , transforms : Sequence [Transform [_A , _B ]]):
87+ def __init__ (self , transforms : Sequence [Transform [_B , _C ]]):
8888 self .transforms = transforms
8989
9090 def __str__ (self ) -> str :
@@ -97,10 +97,10 @@ def __str__(self) -> str:
9797 strings .append (s )
9898 return "(" + " | " .join (strings ) + ")"
9999
100- def __call__ (self , tensor_dict : _A ) -> _B :
100+ def __call__ (self , tensor_dict : _B ) -> _C :
101101 tensor_dicts = [transform (tensor_dict ) for transform in self .transforms ]
102- output_type : type [_A ] = EmptyTensorDict
103- output : _A = EmptyTensorDict ()
102+ output_type : type [_B ] = EmptyTensorDict
103+ output : _B = EmptyTensorDict ()
104104 for tensor_dict in tensor_dicts :
105105 output_type = _least_common_ancestor (output_type , type (tensor_dict ))
106106 output |= tensor_dict
0 commit comments