Skip to content

Commit 0b501db

Browse files
committed
[fix] Fix gym_env error when loading scene without robots
1 parent 7159d00 commit 0b501db

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

grutopia/core/gym_env.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def _validate(self):
4646
log.debug(f'================ len(episodes): {len(episodes)} ==================')
4747

4848
for runtime in self._runtime.task_runtime_manager.episodes:
49+
if len(runtime.robots) == 0:
50+
return
4951
if len(runtime.robots) != 1:
5052
raise ValueError(f'Only support single agent now, but episode requires {len(runtime.robots)} agents')
5153
if robot_name is None:
@@ -76,6 +78,7 @@ def reset(self, *, seed=None, options=None) -> tuple[gym.Space, dict[str, Any]]:
7678
info (dictionary): Contains the key `task_runtime` if there is an unfinished task
7779
"""
7880
info = {}
81+
obs = {}
7982

8083
origin_obs, task_runtime = self.runner.reset(self._current_task_name)
8184
if task_runtime is None:
@@ -84,7 +87,8 @@ def reset(self, *, seed=None, options=None) -> tuple[gym.Space, dict[str, Any]]:
8487

8588
self._current_task_name = task_runtime.name
8689
info[Env.RESET_INFO_TASK_RUNTIME] = task_runtime
87-
obs = origin_obs[task_runtime.name][self._robot_name]
90+
if self._robot_name:
91+
obs = origin_obs[task_runtime.name][self._robot_name]
8892

8993
return obs, info
9094

@@ -124,7 +128,8 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]:
124128
if rewards[self._current_task_name] != -1:
125129
reward = rewards[self._current_task_name]
126130

127-
obs = origin_obs[self._current_task_name][self._robot_name]
131+
if self._robot_name:
132+
obs = origin_obs[self._current_task_name][self._robot_name]
128133
terminated = terminated_status[self._current_task_name]
129134

130135
return obs, reward, terminated, truncated, info
@@ -160,6 +165,8 @@ def get_observations(self) -> dict[Any, Any] | Any:
160165
return {}
161166

162167
_obs = self._runner.get_obs()
168+
if self._robot_name is None:
169+
return {}
163170
return _obs[self._current_task_name][self._robot_name]
164171

165172
def render(self, mode='human'):

tests/e2e_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def test_h1_locomotion():
4949
common_body(start_command)
5050

5151

52+
@pytest.mark.P0
53+
def test_load_scene_without_robot():
54+
start_command = 'python ./tests/load_scene_without_robot.py'
55+
common_body(start_command)
56+
57+
5258
@pytest.mark.P0
5359
def test_rep_camera_pointcloud():
5460
start_command = 'python ./tests/rep_camera_pointcloud.py'

tests/load_scene_without_robot.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
def main():
2+
from grutopia.core.config import Config, SimConfig
3+
from grutopia.core.gym_env import Env
4+
from grutopia.core.runtime import SimulatorRuntime
5+
from grutopia.core.util import has_display
6+
from grutopia.macros import gm
7+
from grutopia_extension import import_extensions
8+
from grutopia_extension.configs.tasks import (
9+
SingleInferenceEpisodeCfg,
10+
SingleInferenceTaskCfg,
11+
)
12+
13+
headless = False
14+
if not has_display():
15+
headless = True
16+
17+
config = Config(
18+
simulator=SimConfig(physics_dt=1 / 240, rendering_dt=1 / 240, use_fabric=False),
19+
task_config=SingleInferenceTaskCfg(
20+
episodes=[
21+
SingleInferenceEpisodeCfg(
22+
scene_asset_path=gm.ASSET_PATH + '/scenes/empty.usd',
23+
scene_scale=(0.01, 0.01, 0.01),
24+
robots=[],
25+
),
26+
],
27+
),
28+
)
29+
30+
print(config.model_dump_json(indent=4))
31+
32+
sim_runtime = SimulatorRuntime(config_class=config, headless=headless, native=headless)
33+
34+
import_extensions()
35+
36+
env = Env(sim_runtime)
37+
obs, _ = env.reset()
38+
print(f'========INIT OBS{obs}=============')
39+
40+
i = 0
41+
42+
while env.simulation_app.is_running():
43+
i += 1
44+
obs, _, terminated, _, _ = env.step(action={})
45+
46+
if i == 2000:
47+
break
48+
49+
env.close()
50+
51+
52+
if __name__ == '__main__':
53+
try:
54+
main()
55+
except Exception as e:
56+
print(f'exception is {e}')
57+
import sys
58+
import traceback
59+
60+
traceback.print_exc()
61+
sys.exit(1)

0 commit comments

Comments
 (0)