7272def _apply_to_composite (function ):
7373 def new_fun (self , observation_spec ):
7474 if isinstance (observation_spec , CompositeSpec ):
75- d = copy ( observation_spec ._specs )
75+ d = observation_spec ._specs
7676 for key_in , key_out in zip (self .keys_in , self .keys_out ):
7777 if key_in in observation_spec .keys ():
7878 d [key_out ] = function (self , observation_spec [key_in ])
@@ -506,7 +506,9 @@ def __getattr__(self, attr: str) -> Any:
506506 )
507507
508508 def __repr__ (self ) -> str :
509- return f"TransformedEnv(env={ self .base_env } , transform={ self .transform } )"
509+ env_str = indent (f"env={ self .base_env } " , 4 * " " )
510+ t_str = indent (f"transform={ self .transform } " , 4 * " " )
511+ return f"TransformedEnv(\n { env_str } ,\n { t_str } )"
510512
511513 def _erase_metadata (self ):
512514 if self .cache_specs :
@@ -621,7 +623,9 @@ def __getitem__(self, item: Union[int, slice, List]) -> Union:
621623 transform = self .transforms
622624 transform = transform [item ]
623625 if not isinstance (transform , Transform ):
624- return Compose (* self .transforms [item ])
626+ out = Compose (* self .transforms [item ])
627+ out .set_parent (self .parent )
628+ return out
625629 return transform
626630
627631 def dump (self , ** kwargs ) -> None :
@@ -737,7 +741,7 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor:
737741
738742 @_apply_to_composite
739743 def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
740- self ._pixel_observation (observation_spec )
744+ observation_spec = self ._pixel_observation (deepcopy ( observation_spec ) )
741745 observation_spec .shape = torch .Size (
742746 [
743747 * observation_spec .shape [:- 3 ],
@@ -747,13 +751,13 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
747751 ]
748752 )
749753 observation_spec .dtype = self .dtype
750- observation_spec = observation_spec
751754 return observation_spec
752755
753756 def _pixel_observation (self , spec : TensorSpec ) -> None :
754- if isinstance (spec , BoundedTensorSpec ):
757+ if isinstance (spec . space , ContinuousBox ):
755758 spec .space .maximum = self ._apply_transform (spec .space .maximum )
756759 spec .space .minimum = self ._apply_transform (spec .space .minimum )
760+ return spec
757761
758762
759763class RewardClipping (Transform ):
@@ -899,6 +903,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
899903
900904 @_apply_to_composite
901905 def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
906+ observation_spec = deepcopy (observation_spec )
902907 space = observation_spec .space
903908 if isinstance (space , ContinuousBox ):
904909 space .minimum = self ._apply_transform (space .minimum )
@@ -962,7 +967,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
962967 }
963968 )
964969 else :
965- _observation_spec = observation_spec
970+ _observation_spec = deepcopy (observation_spec )
971+
966972 space = _observation_spec .space
967973 if isinstance (space , ContinuousBox ):
968974 space .minimum = self ._apply_transform (space .minimum )
@@ -1019,6 +1025,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
10191025
10201026 @_apply_to_composite
10211027 def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1028+ observation_spec = deepcopy (observation_spec )
10221029 space = observation_spec .space
10231030 if isinstance (space , ContinuousBox ):
10241031 space .minimum = self ._apply_transform (space .minimum )
@@ -1122,25 +1129,26 @@ def _transform_spec(self, spec: TensorSpec) -> None:
11221129 spec .shape = space .minimum .shape
11231130 else :
11241131 spec .shape = self ._apply_transform (torch .zeros (spec .shape )).shape
1132+ return spec
11251133
11261134 def transform_action_spec (self , action_spec : TensorSpec ) -> TensorSpec :
11271135 if "action" in self .keys_inv_in :
1128- self ._transform_spec (action_spec )
1136+ action_spec = self ._transform_spec (deepcopy ( action_spec ) )
11291137 return action_spec
11301138
11311139 def transform_input_spec (self , input_spec : TensorSpec ) -> TensorSpec :
11321140 for key in self .keys_inv_in :
1133- self ._transform_spec (input_spec [key ])
1141+ input_spec = self ._transform_spec (deepcopy ( input_spec [key ]) )
11341142 return input_spec
11351143
11361144 def transform_reward_spec (self , reward_spec : TensorSpec ) -> TensorSpec :
11371145 if "reward" in self .keys_in :
1138- self ._transform_spec (reward_spec )
1146+ reward_spec = self ._transform_spec (deepcopy ( reward_spec ) )
11391147 return reward_spec
11401148
11411149 @_apply_to_composite
11421150 def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1143- self ._transform_spec (observation_spec )
1151+ observation_spec = self ._transform_spec (deepcopy ( observation_spec ) )
11441152 return observation_spec
11451153
11461154 def __repr__ (self ) -> str :
@@ -1207,6 +1215,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
12071215
12081216 @_apply_to_composite
12091217 def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1218+ observation_spec = deepcopy (observation_spec )
12101219 space = observation_spec .space
12111220 if isinstance (space , ContinuousBox ):
12121221 space .minimum = self ._apply_transform (space .minimum )
@@ -1295,6 +1304,7 @@ def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
12951304
12961305 @_apply_to_composite
12971306 def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
1307+ observation_spec = deepcopy (observation_spec )
12981308 space = observation_spec .space
12991309 if isinstance (space , ContinuousBox ):
13001310 space .minimum = self ._apply_transform (space .minimum )
0 commit comments