@@ -49,12 +49,14 @@ def __init__(
4949 batch_size : torch .Size ,
5050 env_str : str ,
5151 device : torch .device ,
52+ batch_locked : bool = True ,
5253 ):
5354 self .tensordict = tensordict
5455 self .specs = specs
5556 self .batch_size = batch_size
5657 self .env_str = env_str
5758 self .device = device
59+ self .batch_locked = batch_locked
5860
5961 @staticmethod
6062 def build_metadata_from_env (env ) -> EnvMetaData :
@@ -64,19 +66,27 @@ def build_metadata_from_env(env) -> EnvMetaData:
6466 batch_size = env .batch_size
6567 env_str = str (env )
6668 device = env .device
67- return EnvMetaData (tensordict , specs , batch_size , env_str , device )
69+ batch_locked = env .batch_locked
70+ return EnvMetaData (tensordict , specs , batch_size , env_str , device , batch_locked )
6871
def expand(self, *size: int) -> EnvMetaData:
    """Return a new ``EnvMetaData`` whose tensordict is expanded to ``size``.

    Args:
        *size: target dimensions; forwarded to ``self.tensordict.expand`` and
            also used to build the ``batch_size`` of the returned metadata.

    Returns:
        A new ``EnvMetaData`` holding the expanded tensordict and the new
        batch size. ``specs``, ``env_str``, ``device`` and ``batch_locked``
        are passed through from ``self`` as-is.
    """
    # The expanded result is converted with ``to_tensordict()`` before it is
    # stored on the new metadata object.
    expanded = self.tensordict.expand(*size).to_tensordict()
    new_batch_size = torch.Size([*size])
    return EnvMetaData(
        expanded,
        self.specs,
        new_batch_size,
        self.env_str,
        self.device,
        self.batch_locked,
    )
7583
def to(self, device: DEVICE_TYPING) -> EnvMetaData:
    """Return a new ``EnvMetaData`` with tensordict and specs sent to ``device``.

    Args:
        device: destination handed to ``self.tensordict.to`` and
            ``self.specs.to``; it is also stored as the ``device`` of the
            returned metadata.

    Returns:
        A new ``EnvMetaData``. ``batch_size``, ``env_str`` and
        ``batch_locked`` are carried over from ``self`` unchanged.
    """
    moved_tensordict = self.tensordict.to(device)
    moved_specs = self.specs.to(device)
    return EnvMetaData(
        moved_tensordict,
        moved_specs,
        self.batch_size,
        self.env_str,
        device,
        self.batch_locked,
    )
8090
8191 def __setstate__ (self , state ):
8292 state ["tensordict" ] = state ["tensordict" ].to_tensordict ().to (state ["device" ])
@@ -218,10 +228,24 @@ def __init__(
218228 self .batch_size = torch .Size ([])
219229
@classmethod
def __new__(cls, *args, _batch_locked=True, **kwargs):
    """Create the instance and record construction-time flags on the class.

    Args:
        *args: accepted and ignored here; not forwarded to the parent
            ``__new__``.
        _batch_locked: value stored as ``cls._batch_locked`` (default True).
        **kwargs: accepted and ignored here; not forwarded to the parent
            ``__new__``.

    Returns:
        The object produced by ``super().__new__(cls)``.
    """
    # NOTE(review): both flags are assigned on ``cls``, not on the new
    # instance, so every instance of the same class observes the value set by
    # the most recent construction - confirm this is intended.
    # NOTE(review): ``__new__`` is implicitly a static method; with the
    # explicit ``@classmethod`` the class is received twice (once bound, once
    # as the first element of ``*args``). Kept as-is to preserve behaviour.
    cls._inplace_update = True
    cls._batch_locked = _batch_locked
    return super().__new__(cls)
224235
@property
def batch_locked(self) -> bool:
    """Whether the environment is restricted to the batch size it was built with.

    If True, the env needs to be used with a tensordict having the same batch
    size as the env. ``batch_locked`` is an immutable property: the value is
    read from ``self._batch_locked`` and assigning to it always raises.
    """
    return self._batch_locked

@batch_locked.setter
def batch_locked(self, value: bool) -> None:
    """Reject any assignment to ``batch_locked``.

    Args:
        value: ignored; present only to satisfy the setter signature.

    Raises:
        RuntimeError: always, because the property is read-only.
    """
    raise RuntimeError("batch_locked is a read-only property")
248+
225249 @property
226250 def action_spec (self ) -> TensorSpec :
227251 return self ._action_spec
@@ -272,6 +296,8 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
272296 """
273297
274298 # sanity check
299+ self ._assert_tensordict_shape (tensordict )
300+
275301 if tensordict .get ("action" ).dtype is not self .action_spec .dtype :
276302 raise TypeError (
277303 f"expected action.dtype to be { self .action_spec .dtype } "
@@ -408,7 +434,9 @@ def set_state(self):
408434 raise NotImplementedError
409435
410436 def _assert_tensordict_shape (self , tensordict : TensorDictBase ) -> None :
411- if tensordict .batch_size != self .batch_size :
437+ if tensordict .batch_size != self .batch_size and (
438+ self .batch_locked or self .batch_size != torch .Size ([])
439+ ):
412440 raise RuntimeError (
413441 f"Expected a tensordict with shape==env.shape, "
414442 f"got { tensordict .batch_size } and { self .batch_size } "
@@ -531,7 +559,9 @@ def policy(td):
531559 else :
532560 raise Exception ("reset env before calling rollout!" )
533561
534- out_td = torch .stack (tensordicts , len (self .batch_size ))
562+ batch_size = self .batch_size if tensordict is None else tensordict .batch_size
563+
564+ out_td = torch .stack (tensordicts , len (batch_size ))
535565 if return_contiguous :
536566 return out_td .contiguous ()
537567 return out_td
0 commit comments