@@ -129,9 +129,14 @@ class DiscreteActionVecMockEnv(_MockEnv):
129129 )
130130 action_spec = OneHotDiscreteTensorSpec (7 )
131131 reward_spec = UnboundedContinuousTensorSpec ()
132+
132133 from_pixels = False
133134
134135 out_key = "observation"
136+ _out_key = "observation_orig"
137+ input_spec = CompositeSpec (
138+ ** {_out_key : observation_spec ["next_observation" ], "action" : action_spec }
139+ )
135140
136141 def _get_in_obs (self , obs ):
137142 return obs
@@ -145,6 +150,7 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
145150 tensordict = tensordict .select ().set (
146151 "next_" + self .out_key , self ._get_out_obs (state )
147152 )
153+ tensordict = tensordict .set ("next_" + self ._out_key , self ._get_out_obs (state ))
148154 tensordict .set ("done" , torch .zeros (* tensordict .shape , 1 , dtype = torch .bool ))
149155 return tensordict
150156
@@ -157,12 +163,12 @@ def _step(
157163 assert (a .sum (- 1 ) == 1 ).all ()
158164 assert not self .is_done , "trying to execute step in done env"
159165
160- obs = (
161- self ._get_in_obs (self .current_tensordict .get (self .out_key ))
162- + a / self .maxstep
163- )
166+ obs = self ._get_in_obs (tensordict .get (self ._out_key )) + a / self .maxstep
164167 tensordict = tensordict .select () # empty tensordict
168+
165169 tensordict .set ("next_" + self .out_key , self ._get_out_obs (obs ))
170+ tensordict .set ("next_" + self ._out_key , self ._get_out_obs (obs ))
171+
166172 done = torch .isclose (obs , torch .ones_like (obs ) * (self .counter + 1 ))
167173 reward = done .any (- 1 ).unsqueeze (- 1 )
168174 # set done to False
@@ -182,6 +188,10 @@ class ContinuousActionVecMockEnv(_MockEnv):
182188 from_pixels = False
183189
184190 out_key = "observation"
191+ _out_key = "observation_orig"
192+ input_spec = CompositeSpec (
193+ ** {_out_key : observation_spec ["next_observation" ], "action" : action_spec }
194+ )
185195
186196 def _get_in_obs (self , obs ):
187197 return obs
@@ -193,9 +203,9 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
193203 self .counter += 1
194204 self .step_count = 0
195205 state = torch .zeros (self .size ) + self .counter
196- tensordict = tensordict .select (). set (
197- "next_" + self .out_key , self ._get_out_obs (state )
198- )
206+ tensordict = tensordict .select ()
207+ tensordict . set ( "next_" + self .out_key , self ._get_out_obs (state ) )
208+ tensordict . set ( "next_" + self . _out_key , self . _get_out_obs ( state ) )
199209 tensordict .set ("done" , torch .zeros (* tensordict .shape , 1 , dtype = torch .bool ))
200210 return tensordict
201211
@@ -208,11 +218,12 @@ def _step(
208218 a = tensordict .get ("action" )
209219 assert not self .is_done , "trying to execute step in done env"
210220
211- obs = self ._obs_step (
212- self ._get_in_obs (self .current_tensordict .get (self .out_key )), a
213- )
221+ obs = self ._obs_step (self ._get_in_obs (tensordict .get (self ._out_key )), a )
214222 tensordict = tensordict .select () # empty tensordict
223+
215224 tensordict .set ("next_" + self .out_key , self ._get_out_obs (obs ))
225+ tensordict .set ("next_" + self ._out_key , self ._get_out_obs (obs ))
226+
216227 done = torch .isclose (obs , torch .ones_like (obs ) * (self .counter + 1 ))
217228 reward = done .any (- 1 ).unsqueeze (- 1 )
218229 done = done .all (- 1 ).unsqueeze (- 1 )
@@ -251,6 +262,10 @@ class DiscreteActionConvMockEnv(DiscreteActionVecMockEnv):
251262 from_pixels = True
252263
253264 out_key = "pixels"
265+ _out_key = "pixels_orig"
266+ input_spec = CompositeSpec (
267+ ** {_out_key : observation_spec ["next_pixels" ], "action" : action_spec }
268+ )
254269
255270 def _get_out_obs (self , obs ):
256271 obs = torch .diag_embed (obs , 0 , - 2 , - 1 ).unsqueeze (0 )
@@ -287,6 +302,10 @@ class ContinuousActionConvMockEnv(ContinuousActionVecMockEnv):
287302 from_pixels = True
288303
289304 out_key = "pixels"
305+ _out_key = "pixels_orig"
306+ input_spec = CompositeSpec (
307+ ** {_out_key : observation_spec ["next_pixels" ], "action" : action_spec }
308+ )
290309
291310 def _get_out_obs (self , obs ):
292311 obs = torch .diag_embed (obs , 0 , - 2 , - 1 ).unsqueeze (0 )
0 commit comments