@@ -46,6 +46,8 @@ def _validate(self):
4646 log .debug (f'================ len(episodes): { len (episodes )} ==================' )
4747
4848 for runtime in self ._runtime .task_runtime_manager .episodes :
49+ if len (runtime .robots ) == 0 :
50+ return
4951 if len (runtime .robots ) != 1 :
5052 raise ValueError (f'Only support single agent now, but episode requires { len (runtime .robots )} agents' )
5153 if robot_name is None :
@@ -76,6 +78,7 @@ def reset(self, *, seed=None, options=None) -> tuple[gym.Space, dict[str, Any]]:
7678 info (dictionary): Contains the key `task_runtime` if there is an unfinished task
7779 """
7880 info = {}
81+ obs = {}
7982
8083 origin_obs , task_runtime = self .runner .reset (self ._current_task_name )
8184 if task_runtime is None :
@@ -84,7 +87,8 @@ def reset(self, *, seed=None, options=None) -> tuple[gym.Space, dict[str, Any]]:
8487
8588 self ._current_task_name = task_runtime .name
8689 info [Env .RESET_INFO_TASK_RUNTIME ] = task_runtime
87- obs = origin_obs [task_runtime .name ][self ._robot_name ]
90+ if self ._robot_name :
91+ obs = origin_obs [task_runtime .name ][self ._robot_name ]
8892
8993 return obs , info
9094
@@ -124,7 +128,8 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
124128 if rewards [self ._current_task_name ] != - 1 :
125129 reward = rewards [self ._current_task_name ]
126130
127- obs = origin_obs [self ._current_task_name ][self ._robot_name ]
131+ if self ._robot_name :
132+ obs = origin_obs [self ._current_task_name ][self ._robot_name ]
128133 terminated = terminated_status [self ._current_task_name ]
129134
130135 return obs , reward , terminated , truncated , info
@@ -160,6 +165,8 @@ def get_observations(self) -> dict[Any, Any] | Any:
160165 return {}
161166
162167 _obs = self ._runner .get_obs ()
168+ if self ._robot_name is None :
169+ return {}
163170 return _obs [self ._current_task_name ][self ._robot_name ]
164171
165172 def render (self , mode = 'human' ):
0 commit comments