
Commit 2e98102

puyuan1996 and puyuan authored
fix(pu): adapt atari and dmc2gym env to support shared_memory (#345)
* fix(pu): fix atari and dmc2gym env to support shared_memory
* tmp
* fix(pu): fix frame_stack_num default cfg in atari env

---------

Co-authored-by: puyuan <[email protected]>
1 parent 68a0c38 commit 2e98102

File tree

4 files changed: +66 −18 lines changed


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1449,3 +1449,5 @@ events.*
 # pooltool-specific stuff
 !/assets/pooltool/**
 lzero/mcts/ctree/ctree_alphazero/pybind11
+
+zoo/jericho/envs/z-machine-games-master

zoo/atari/config/atari_muzero_config.py

Lines changed: 6 additions & 7 deletions
@@ -14,7 +14,7 @@
 update_per_collect = None
 replay_ratio = 0.25
 batch_size = 256
-max_env_step = int(2e5)
+max_env_step = int(5e5)
 reanalyze_ratio = 0.
 
 # =========== for debug ===========
@@ -33,13 +33,13 @@
     env=dict(
         stop_value=int(1e6),
         env_id=env_id,
-        observation_shape=(4, 64, 64),  # (4, 96, 96)
+        observation_shape=(4, 64, 64),
         frame_stack_num=4,
         gray_scale=True,
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
         n_evaluator_episode=evaluator_env_num,
-        manager=dict(shared_memory=False, ),
+        manager=dict(shared_memory=True, ),
         # TODO: debug
         # collect_max_episode_steps=int(50),
         # eval_max_episode_steps=int(50),
@@ -48,17 +48,16 @@
         analysis_sim_norm=False,
         cal_dormant_ratio=False,
         model=dict(
-            observation_shape=(4, 64, 64),  # (4, 96, 96)
-            image_channel=1,
+            observation_shape=(4, 64, 64),
             frame_stack_num=4,
+            image_channel=1,
             gray_scale=True,
             action_space_size=action_space_size,
             downsample=True,
-            self_supervised_learning_loss=True,  # default is False
+            self_supervised_learning_loss=True,
             discrete_action_encoding_type='one_hot',
             norm_type='BN',
             use_sim_norm=True,
-            use_sim_norm_kl_loss=False,
             model_type='conv'
         ),
         cuda=True,
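The key functional change in this config is flipping manager=dict(shared_memory=True, ). With a subprocess env manager, shared_memory typically preallocates one buffer per observation field from the env's declared observation space, so every field needs a fixed shape and dtype; that is what the env-side patches below provide. The following is an illustrative sketch only (not DI-engine's actual implementation) of how such buffers might be sized from a gym.spaces.Dict, using space values mirroring the Atari settings above.

# Illustrative sketch: sizing one shared buffer per Dict key is only
# possible when every sub-space declares a concrete shape and dtype.
import numpy as np
import gym
from multiprocessing import shared_memory

obs_space = gym.spaces.Dict({
    'observation': gym.spaces.Box(low=0, high=1, shape=(1, 64, 64), dtype=np.float32),
    'action_mask': gym.spaces.Box(low=0, high=1, shape=(6,), dtype=np.int8),
    'to_play': gym.spaces.Box(low=-1, high=2, shape=(), dtype=np.int8),
    'timestep': gym.spaces.Box(low=0, high=int(1e5), shape=(), dtype=np.int32),
})

buffers = {}
for key, space in obs_space.spaces.items():
    n_items = int(np.prod(space.shape)) if space.shape else 1  # scalar Box -> 1 element
    buffers[key] = shared_memory.SharedMemory(create=True, size=n_items * space.dtype.itemsize)

# Worker processes would attach to these blocks by name and write ndarray
# views into them, which relies on each field having a fixed shape and dtype.
for buf in buffers.values():
    buf.close()
    buf.unlink()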

zoo/atari/envs/atari_lightzero_env.py

Lines changed: 28 additions & 3 deletions
@@ -2,7 +2,6 @@
 from ditk import logging
 from typing import List
 
-# import gymnasium as gym
 import gym
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep
@@ -50,6 +49,9 @@ class AtariEnvLightZero(BaseEnv):
         replay_path=None,
         # (bool) If set to True, the game screen is converted to grayscale, reducing the complexity of the observation space.
         gray_scale=True,
+        # (int) Specifies the number of consecutive frames to stack after collecting environment data.
+        # The stacking process is applied within the collector and evaluator modules.
+        frame_stack_num=1,
         # (int) The number of frames to skip between each action. Higher values result in faster simulation.
         frame_skip=4,
         # (bool) If True, the game ends when the agent loses a life, otherwise, the game only ends when all lives are lost.
@@ -112,7 +114,28 @@ def reset(self) -> dict:
         if not self._init_flag:
             # Create and return the wrapped environment for Atari LightZero.
             self._env = wrap_lightzero(self.cfg, episode_life=self.cfg.episode_life, clip_rewards=self.cfg.clip_rewards)
-            self._observation_space = self._env.env.observation_space
+
+            observation_space_before_stack = (
+                int(self.cfg.observation_shape[0] / self.cfg.frame_stack_num),
+                self.cfg.observation_shape[1],
+                self.cfg.observation_shape[2]
+            )
+
+            self._observation_space = gym.spaces.Dict({
+                'observation': gym.spaces.Box(
+                    low=0, high=1, shape=observation_space_before_stack, dtype=np.float32
+                ),
+                'action_mask': gym.spaces.Box(
+                    low=0, high=1, shape=(self._env.env.action_space.n,), dtype=np.int8
+                ),
+                'to_play': gym.spaces.Box(
+                    low=-1, high=2, shape=(), dtype=np.int8
+                ),
+                'timestep': gym.spaces.Box(
+                    low=0, high=self.cfg.collect_max_episode_steps, shape=(), dtype=np.int32
+                ),
+            })
+
             self._action_space = self._env.env.action_space
             self._reward_space = gym.spaces.Box(
                 low=self._env.env.reward_range[0], high=self._env.env.reward_range[1], shape=(1,), dtype=np.float32
@@ -174,8 +197,10 @@ def observe(self) -> dict:
         observation = np.transpose(observation, (2, 0, 1))
 
         action_mask = np.ones(self._action_space.n, 'int8')
-        return {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
 
+        return {'observation': observation, 'action_mask': action_mask, 'to_play': np.array(-1), 'timestep': np.array(self._timestep)}
+
+
     @property
     def legal_actions(self):
         return np.arange(self._action_space.n)
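Two details of this patch are worth spelling out. First, frame stacking happens in the collector/evaluator rather than inside the env, so the declared per-step observation shape divides the stacked channel count by frame_stack_num: with observation_shape=(4, 64, 64) and frame_stack_num=4, the env itself emits (1, 64, 64) frames. Second, to_play and timestep are now returned as 0-d np.array values instead of bare Python ints, so each field has a concrete shape that a shared-memory slot can be sized for. A small illustrative check, assuming the Atari config values above (the dtypes are pinned explicitly here for the asserts; the patch itself uses plain np.array):

import numpy as np

# Per-step shape emitted by the env before the collector stacks frames.
observation_shape = (4, 64, 64)   # stacked shape from the config
frame_stack_num = 4
per_frame_shape = (observation_shape[0] // frame_stack_num, *observation_shape[1:])
assert per_frame_shape == (1, 64, 64)

# Wrapping scalars as 0-d arrays gives them a concrete shape (and dtype),
# unlike bare Python ints.
to_play = np.array(-1, dtype=np.int8)
timestep = np.array(0, dtype=np.int32)
assert to_play.shape == () and timestep.shape == ()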

zoo/dmc2gym/envs/dmc2gym_lightzero_env.py

Lines changed: 30 additions & 8 deletions
@@ -5,7 +5,8 @@
 from typing import Optional
 
 import dmc2gym
-import gymnasium as gym
+# import gymnasium as gym
+import gym
 import matplotlib.pyplot as plt
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, WarpFrameWrapper, ScaledFloatFrameWrapper, \
@@ -255,7 +256,7 @@ def __init__(self, cfg: dict = {}) -> None:
         self._init_flag = False
         self._replay_path = self._cfg.replay_path
 
-        self._observation_space = dmc2gym_env_info[self._cfg.domain_name][self._cfg.task_name]["observation_space"](
+        self._observation_space_origin = dmc2gym_env_info[self._cfg.domain_name][self._cfg.task_name]["observation_space"](
            from_pixels=self._cfg["from_pixels"],
            height=self._cfg["height"],
            width=self._cfg["width"],
@@ -300,7 +301,28 @@ def reset(self) -> Dict[str, np.ndarray]:
             self._env = FrameStackWrapper(self._env, self._cfg['frame_stack'])
 
             # set the obs, action space of wrapped env
-            self._observation_space = self._env.observation_space
+            self._observation_space = gym.spaces.Dict({
+                'observation': self._observation_space_origin,
+                'action_mask': gym.spaces.Box(
+                    low=0,
+                    high=1,
+                    shape=(1,),
+                    dtype=np.int8
+                ),
+                'to_play': gym.spaces.Box(
+                    low=-1,
+                    high=2,
+                    shape=(),
+                    dtype=np.int8
+                ),
+                'timestep': gym.spaces.Box(
+                    low=0,
+                    high=self._cfg.collect_max_episode_steps,
+                    shape=(),
+                    dtype=np.int32
+                ),
+            })
+
             self._action_space = self._env.action_space
 
             if self._replay_path is not None:
@@ -330,13 +352,13 @@ def reset(self) -> Dict[str, np.ndarray]:
             obs = obs['state']
 
         obs = to_ndarray(obs).astype(np.float32)
-        action_mask = None
 
         self._timestep = 0
         if self._save_replay_gif:
             self._frames = []
-
-        obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+
+        action_mask = -1
+        obs = {'observation': obs, 'action_mask': np.array(action_mask), 'to_play': np.array(-1), 'timestep': np.array(self._timestep)}
 
         return obs
 
@@ -406,8 +428,8 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
                 print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!')
                 self._save_replay_count += 1
 
-        action_mask = None
-        obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+        action_mask = -1
+        obs = {'observation': obs, 'action_mask': np.array(action_mask), 'to_play': np.array(-1), 'timestep': np.array(self._timestep)}
 
         return BaseEnvTimestep(obs, rew, done, info)
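The dmc2gym changes mirror the Atari ones: the observation space becomes a fully typed gym.spaces.Dict (wrapping the original space, now kept in self._observation_space_origin), and action_mask switches from None to np.array(-1), presumably a fixed-size placeholder since DMC control tasks use continuous actions and have no real mask. A rough, illustrative-only sketch of why None is problematic for a shared-memory layout while a 0-d array is not:

import numpy as np

slot = np.empty((), dtype=np.int8)   # a preallocated scalar slot in a shared buffer
slot[...] = np.array(-1)             # a 0-d array has a shape and dtype to copy from
assert int(slot) == -1

try:
    slot[...] = None                 # None has neither, so it cannot be stored
except (TypeError, ValueError):
    print("None cannot be written into a typed buffer slot")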
