22
33import torch
44from tensordict .tensordict import TensorDict , TensorDictBase
5+
56from torchrl .data import CompositeSpec , DEVICE_TYPING , UnboundedContinuousTensorSpec
67from torchrl .envs .common import _EnvWrapper , EnvBase
78from torchrl .envs .libs .gym import _gym_to_torchrl_spec_transform
@@ -107,25 +108,6 @@ def __init__(
107108 raise TypeError ("Env device is different from vmas device" )
108109 kwargs ["device" ] = str (env .device )
109110 super ().__init__ (** kwargs )
110- if len (self .batch_size ) == 0 :
111- # Batch size not set
112- self .batch_size = torch .Size ((self .num_envs ,))
113- elif len (self .batch_size ) == 1 :
114- # Batch size is set
115- if not self .batch_size [0 ] == self .num_envs :
116- raise TypeError (
117- "Batch size used in constructor does not match vmas batch size."
118- )
119- else :
120- raise TypeError (
121- "Batch size used in constructor is not compatible with vmas."
122- )
123- self .batch_size = torch .Size ([self .n_agents , * self .batch_size ])
124- self .input_spec = self .input_spec .expand (self .batch_size )
125- self .observation_spec = self .observation_spec .expand (self .batch_size )
126- self .reward_spec = self .reward_spec .expand (
127- [* self .batch_size , * self .reward_spec .shape ]
128- )
129111
130112 @property
131113 def lib (self ):
@@ -144,6 +126,22 @@ def _build_env(
144126 if self .from_pixels :
145127 raise NotImplementedError ("vmas rendering not yet implemented" )
146128
129+ # Adjust batch size
130+ if len (self .batch_size ) == 0 :
131+ # Batch size not set
132+ self .batch_size = torch .Size ((env .num_envs ,))
133+ elif len (self .batch_size ) == 1 :
134+ # Batch size is set
135+ if not self .batch_size [0 ] == env .num_envs :
136+ raise TypeError (
137+ "Batch size used in constructor does not match vmas batch size."
138+ )
139+ else :
140+ raise TypeError (
141+ "Batch size used in constructor is not compatible with vmas."
142+ )
143+ self .batch_size = torch .Size ([env .n_agents , * self .batch_size ])
144+
147145 return env
148146
149147 def _make_specs (
0 commit comments