@@ -121,7 +121,7 @@ def __new__(
121121 action_spec = NdUnboundedContinuousTensorSpec ((1 ,))
122122 if observation_spec is None :
123123 observation_spec = CompositeSpec (
124- next_observation = NdUnboundedContinuousTensorSpec ((1 ,))
124+ observation = NdUnboundedContinuousTensorSpec ((1 ,))
125125 )
126126 if reward_spec is None :
127127 reward_spec = NdUnboundedContinuousTensorSpec ((1 ,))
@@ -152,19 +152,17 @@ def _step(self, tensordict):
152152 )
153153 done = self .counter >= self .max_val
154154 done = torch .tensor ([done ], dtype = torch .bool , device = self .device )
155- return TensorDict (
156- {"reward" : n , "done" : done , "next_observation" : n .clone ()}, []
157- )
155+ return TensorDict ({"reward" : n , "done" : done , "observation" : n .clone ()}, [])
158156
159- def _reset (self , tensordict : TensorDictBase , ** kwargs ) -> TensorDictBase :
157+ def _reset (self , tensordict : TensorDictBase = None , ** kwargs ) -> TensorDictBase :
160158 self .max_val = max (self .counter + 100 , self .counter * 2 )
161159
162160 n = torch .tensor (
163161 [self .counter ], device = self .device , dtype = torch .get_default_dtype ()
164162 )
165163 done = self .counter >= self .max_val
166164 done = torch .tensor ([done ], dtype = torch .bool , device = self .device )
167- return TensorDict ({"done" : done , "next_observation" : n }, [])
165+ return TensorDict ({"done" : done , "observation" : n }, [])
168166
169167 def rand_step (self , tensordict : Optional [TensorDictBase ] = None ) -> TensorDictBase :
170168 return self .step (tensordict )
@@ -192,7 +190,7 @@ def __new__(
192190 )
193191 if observation_spec is None :
194192 observation_spec = CompositeSpec (
195- next_observation = NdUnboundedContinuousTensorSpec ((1 ,))
193+ observation = NdUnboundedContinuousTensorSpec ((1 ,))
196194 )
197195 if reward_spec is None :
198196 reward_spec = NdUnboundedContinuousTensorSpec ((1 ,))
@@ -226,7 +224,7 @@ def _step(self, tensordict):
226224 )
227225
228226 return TensorDict (
229- {"reward" : n , "done" : done , "next_observation" : n },
227+ {"reward" : n , "done" : done , "observation" : n },
230228 tensordict .batch_size ,
231229 device = self .device ,
232230 )
@@ -247,7 +245,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
247245 done = torch .full (batch_size , done , dtype = torch .bool , device = self .device )
248246
249247 return TensorDict (
250- {"reward" : n , "done" : done , "next_observation" : n },
248+ {"reward" : n , "done" : done , "observation" : n },
251249 batch_size ,
252250 device = self .device ,
253251 )
@@ -287,10 +285,8 @@ def __new__(
287285 if observation_spec is None :
288286 cls .out_key = "observation"
289287 observation_spec = CompositeSpec (
290- next_observation = NdUnboundedContinuousTensorSpec (
291- shape = torch .Size ([size ])
292- ),
293- next_observation_orig = NdUnboundedContinuousTensorSpec (
288+ observation = NdUnboundedContinuousTensorSpec (shape = torch .Size ([size ])),
289+ observation_orig = NdUnboundedContinuousTensorSpec (
294290 shape = torch .Size ([size ])
295291 ),
296292 )
@@ -308,7 +304,7 @@ def __new__(
308304 cls ._out_key = "observation_orig"
309305 input_spec = CompositeSpec (
310306 ** {
311- cls ._out_key : observation_spec ["next_observation" ],
307+ cls ._out_key : observation_spec ["observation" ],
312308 "action" : action_spec ,
313309 }
314310 )
@@ -325,15 +321,13 @@ def _get_in_obs(self, obs):
325321 def _get_out_obs (self , obs ):
326322 return obs
327323
328- def _reset (self , tensordict : TensorDictBase ) -> TensorDictBase :
324+ def _reset (self , tensordict : TensorDictBase = None ) -> TensorDictBase :
329325 self .counter += 1
330326 state = torch .zeros (self .size ) + self .counter
331327 if tensordict is None :
332328 tensordict = TensorDict ({}, self .batch_size , device = self .device )
333- tensordict = tensordict .select ().set (
334- "next_" + self .out_key , self ._get_out_obs (state )
335- )
336- tensordict = tensordict .set ("next_" + self ._out_key , self ._get_out_obs (state ))
329+ tensordict = tensordict .select ().set (self .out_key , self ._get_out_obs (state ))
330+ tensordict = tensordict .set (self ._out_key , self ._get_out_obs (state ))
337331 tensordict .set ("done" , torch .zeros (* tensordict .shape , 1 , dtype = torch .bool ))
338332 return tensordict
339333
@@ -351,8 +345,8 @@ def _step(
351345 obs = self ._get_in_obs (tensordict .get (self ._out_key )) + a / self .maxstep
352346 tensordict = tensordict .select () # empty tensordict
353347
354- tensordict .set ("next_" + self .out_key , self ._get_out_obs (obs ))
355- tensordict .set ("next_" + self ._out_key , self ._get_out_obs (obs ))
348+ tensordict .set (self .out_key , self ._get_out_obs (obs ))
349+ tensordict .set (self ._out_key , self ._get_out_obs (obs ))
356350
357351 done = torch .isclose (obs , torch .ones_like (obs ) * (self .counter + 1 ))
358352 reward = done .any (- 1 ).unsqueeze (- 1 )
@@ -379,10 +373,8 @@ def __new__(
379373 if observation_spec is None :
380374 cls .out_key = "observation"
381375 observation_spec = CompositeSpec (
382- next_observation = NdUnboundedContinuousTensorSpec (
383- shape = torch .Size ([size ])
384- ),
385- next_observation_orig = NdUnboundedContinuousTensorSpec (
376+ observation = NdUnboundedContinuousTensorSpec (shape = torch .Size ([size ])),
377+ observation_orig = NdUnboundedContinuousTensorSpec (
386378 shape = torch .Size ([size ])
387379 ),
388380 )
@@ -395,7 +387,7 @@ def __new__(
395387 cls ._out_key = "observation_orig"
396388 input_spec = CompositeSpec (
397389 ** {
398- cls ._out_key : observation_spec ["next_observation" ],
390+ cls ._out_key : observation_spec ["observation" ],
399391 "action" : action_spec ,
400392 }
401393 )
@@ -436,8 +428,8 @@ def _step(
436428 obs = self ._obs_step (self ._get_in_obs (tensordict .get (self ._out_key )), a )
437429 tensordict = tensordict .select () # empty tensordict
438430
439- tensordict .set ("next_" + self .out_key , self ._get_out_obs (obs ))
440- tensordict .set ("next_" + self ._out_key , self ._get_out_obs (obs ))
431+ tensordict .set (self .out_key , self ._get_out_obs (obs ))
432+ tensordict .set (self ._out_key , self ._get_out_obs (obs ))
441433
442434 done = torch .isclose (obs , torch .ones_like (obs ) * (self .counter + 1 ))
443435 reward = done .any (- 1 ).unsqueeze (- 1 )
@@ -483,10 +475,8 @@ def __new__(
483475 if observation_spec is None :
484476 cls .out_key = "pixels"
485477 observation_spec = CompositeSpec (
486- next_pixels = NdUnboundedContinuousTensorSpec (
487- shape = torch .Size ([1 , 7 , 7 ])
488- ),
489- next_pixels_orig = NdUnboundedContinuousTensorSpec (
478+ pixels = NdUnboundedContinuousTensorSpec (shape = torch .Size ([1 , 7 , 7 ])),
479+ pixels_orig = NdUnboundedContinuousTensorSpec (
490480 shape = torch .Size ([1 , 7 , 7 ])
491481 ),
492482 )
@@ -499,7 +489,7 @@ def __new__(
499489 cls ._out_key = "pixels_orig"
500490 input_spec = CompositeSpec (
501491 ** {
502- cls ._out_key : observation_spec ["next_pixels_orig" ],
492+ cls ._out_key : observation_spec ["pixels_orig" ],
503493 "action" : action_spec ,
504494 }
505495 )
@@ -537,10 +527,8 @@ def __new__(
537527 if observation_spec is None :
538528 cls .out_key = "pixels"
539529 observation_spec = CompositeSpec (
540- next_pixels = NdUnboundedContinuousTensorSpec (
541- shape = torch .Size ([7 , 7 , 3 ])
542- ),
543- next_pixels_orig = NdUnboundedContinuousTensorSpec (
530+ pixels = NdUnboundedContinuousTensorSpec (shape = torch .Size ([7 , 7 , 3 ])),
531+ pixels_orig = NdUnboundedContinuousTensorSpec (
544532 shape = torch .Size ([7 , 7 , 3 ])
545533 ),
546534 )
@@ -555,7 +543,7 @@ def __new__(
555543 cls ._out_key = "pixels_orig"
556544 input_spec = CompositeSpec (
557545 ** {
558- cls ._out_key : observation_spec ["next_pixels_orig" ],
546+ cls ._out_key : observation_spec ["pixels_orig" ],
559547 "action" : action_spec ,
560548 }
561549 )
@@ -599,10 +587,8 @@ def __new__(
599587 if observation_spec is None :
600588 cls .out_key = "pixels"
601589 observation_spec = CompositeSpec (
602- next_pixels = NdUnboundedContinuousTensorSpec (
603- shape = torch .Size (pixel_shape )
604- ),
605- next_pixels_orig = NdUnboundedContinuousTensorSpec (
590+ pixels = NdUnboundedContinuousTensorSpec (shape = torch .Size (pixel_shape )),
591+ pixels_orig = NdUnboundedContinuousTensorSpec (
606592 shape = torch .Size (pixel_shape )
607593 ),
608594 )
@@ -615,7 +601,7 @@ def __new__(
615601 if input_spec is None :
616602 cls ._out_key = "pixels_orig"
617603 input_spec = CompositeSpec (
618- ** {cls ._out_key : observation_spec ["next_pixels" ], "action" : action_spec }
604+ ** {cls ._out_key : observation_spec ["pixels" ], "action" : action_spec }
619605 )
620606 return super ().__new__ (
621607 * args ,
@@ -650,10 +636,8 @@ def __new__(
650636 if observation_spec is None :
651637 cls .out_key = "pixels"
652638 observation_spec = CompositeSpec (
653- next_pixels = NdUnboundedContinuousTensorSpec (
654- shape = torch .Size ([7 , 7 , 3 ])
655- ),
656- next_pixels_orig = NdUnboundedContinuousTensorSpec (
639+ pixels = NdUnboundedContinuousTensorSpec (shape = torch .Size ([7 , 7 , 3 ])),
640+ pixels_orig = NdUnboundedContinuousTensorSpec (
657641 shape = torch .Size ([7 , 7 , 3 ])
658642 ),
659643 )
@@ -714,7 +698,7 @@ def __init__(
714698 batch_size = batch_size ,
715699 )
716700 self .observation_spec = CompositeSpec (
717- next_hidden_observation = NdUnboundedContinuousTensorSpec ((4 ,))
701+ hidden_observation = NdUnboundedContinuousTensorSpec ((4 ,))
718702 )
719703 self .input_spec = CompositeSpec (
720704 hidden_observation = NdUnboundedContinuousTensorSpec ((4 ,)),
@@ -728,9 +712,6 @@ def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
728712 "hidden_observation" : self .input_spec ["hidden_observation" ].rand (
729713 self .batch_size
730714 ),
731- "next_hidden_observation" : self .observation_spec [
732- "next_hidden_observation"
733- ].rand (self .batch_size ),
734715 },
735716 batch_size = self .batch_size ,
736717 device = self .device ,
0 commit comments