
Commit 4d9ac1d

add mock gym.py update continuous_wrapper add example test (#685)
* add mock gym.py, update continuous_wrapper, add example test
* yapf
* update build.sh, update gym.py
* update build.sh
* update build.sh
* remove strict equality to prevent infinite loop
* add comment
* add copyright
* delete paddle_speed_test.py
* delete torch_speed_test.py
* add comment
* yapf
* update comment
1 parent c564af9 commit 4d9ac1d

File tree

4 files changed: +245 -11 lines


.teamcity/build.sh

Lines changed: 21 additions & 8 deletions
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+set -ex
 
 function init() {
     RED='\033[0;31m'
@@ -26,6 +27,24 @@ function init() {
     export LD_LIBRARY_PATH="/usr/local/TensorRT-6.0.1.5/lib:$LD_LIBRARY_PATH"
 }
 
+function run_example_test {
+    for exp in QuickStart DQN DQN_variant PPO SAC TD3 OAC DDPG
+    do
+        cp parl/tests/gym.py examples/${exp}/
+    done
+
+    python examples/QuickStart/train.py
+    python examples/DQN/train.py
+    python examples/DQN_variant/train.py --train_total_steps 5000 --algo DQN --env PongNoFrameskip-v4
+    python examples/DQN_variant/train.py --train_total_steps 5000 --algo DDQN --env PongNoFrameskip-v4
+    python examples/DQN_variant/train.py --train_total_steps 5000 --dueling True --env PongNoFrameskip-v4
+    python examples/PPO/train.py --train_total_steps 5000 --env HalfCheetah-v1
+    python examples/SAC/train.py --train_total_steps 5000 --env HalfCheetah-v1
+    python examples/TD3/train.py --train_total_steps 5000 --env HalfCheetah-v1
+    python examples/OAC/train.py --train_total_steps 5000 --env HalfCheetah-v1
+    python examples/DDPG/train.py --train_total_steps 5000 --env HalfCheetah-v1
+}
+
 function print_usage() {
     echo -e "\n${RED}Usage${NONE}:
     ${BOLD}$0${NONE} [OPTION]"
@@ -143,13 +162,6 @@ function run_test_with_fluid() {
     done
 }
 
-function run_cartpole_test {
-    for exp in QuickStart DQN
-    do
-        python examples/${exp}/train.py
-    done
-}
-
 function run_import_test {
     export CUDA_VISIBLE_DEVICES=""
 
@@ -237,7 +249,8 @@ function main() {
         pip install -r .teamcity/requirements.txt
         pip install /data/paddle_package/paddlepaddle_gpu-2.1.0.post101-cp38-cp38-linux_x86_64.whl
         run_test_with_gpu $env
-        run_cartpole_test $env
+        pip install tqdm  # for example test
+        run_example_test $env
 
         run_test_with_fluid
         ############
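Why the cp into each example directory works: Python prepends the running script's directory to sys.path, so the local gym.py shadows the installed gym package when train.py does import gym. A minimal sketch of that resolution order (the examples/DQN path is illustrative, not taken from the commit):

import sys
print(sys.path[0])  # the script's own directory, e.g. '.../examples/DQN'
import gym          # picks up the copied parl/tests/gym.py, not the real gym
env = gym.make('CartPole-v0')  # returns the mock CartPoleEnv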

parl/env/atari_wrappers.py

Lines changed: 17 additions & 2 deletions
@@ -1,5 +1,18 @@
-# Third party code
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Third party code
 # The following code are copied or modified from:
 # https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/atari_wrappers.py
 
@@ -282,7 +295,9 @@ def step(self, action):
 
     def reset(self, **kwargs):
         obs = self._env.reset(**kwargs)
-        if self._get_curr_episode() == self._end_episode:
+        # During the noop reset in NoopResetEnv, the env may be reset multiple times (possible
+        # with the mock env, almost impossible with the Atari env), so == may never be met; >= avoids an infinite loop.
+        if self._get_curr_episode() >= self._end_episode:
             self._was_real_done = True
             self._eval_rewards = \
                 self._monitor.get_episode_rewards()[-self._eval_episodes:]
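A toy sketch of the failure mode the new comment describes, with simplified counter names (not the wrapper's real code): when one reset can advance the episode counter by more than one, strict equality can be stepped over and never fire, while >= always terminates.

import random

curr_episode, end_episode = 0, 5
while True:
    # a noop reset may count several episodes at once, as in the mock env
    curr_episode += random.randint(1, 3)
    # '== end_episode' could jump from 4 to 6 and spin forever; '>=' cannot
    if curr_episode >= end_episode:
        break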

parl/env/continuous_wrappers.py

Lines changed: 6 additions & 1 deletion
@@ -23,7 +23,12 @@ def __init__(self, env):
             [low_bound, high_bound].
         """
         gym.Wrapper.__init__(self, env)
-        assert isinstance(self.env.action_space, gym.spaces.Box)
+        assert hasattr(
+            self.env.action_space,
+            'low'), 'action space should be instance of gym.spaces.Box'
+        assert hasattr(
+            self.env.action_space,
+            'high'), 'action space should be instance of gym.spaces.Box'
         self.low_bound = self.env.action_space.low[0]
         self.high_bound = self.env.action_space.high[0]
         assert self.high_bound > self.low_bound
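The replaced isinstance check would reject the mock Box defined in parl/tests/gym.py, since it is not a real gym.spaces.Box; the hasattr form only assumes the space carries low/high arrays. A minimal sketch of the duck-typed check, with MockBox as a hypothetical stand-in:

import numpy as np

class MockBox(object):  # hypothetical stand-in that quacks like gym.spaces.Box
    def __init__(self, low, high):
        self.low, self.high = low, high

space = MockBox(low=np.array([-1.0]), high=np.array([1.0]))
assert hasattr(space, 'low'), 'action space should be instance of gym.spaces.Box'
assert hasattr(space, 'high'), 'action space should be instance of gym.spaces.Box'
assert space.high[0] > space.low[0]  # passes for the mock as well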

parl/tests/gym.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mock gym environment
+import numpy as np
+from random import random
+
+
+def make(env_name):
+    print('>>>>>>>>> you are testing mock gym env: ', env_name)
+    if env_name == 'CartPole-v0':
+        return CartPoleEnv()
+    elif env_name == 'PongNoFrameskip-v4':
+        return PongEnv()
+    elif env_name == 'HalfCheetah-v1':
+        return HalfCheetahEnv()
+    else:
+        raise NotImplementedError(
+            'Mock env not defined, please check your env name')
+
+
+# mock Box
+class Box(object):
+    def __init__(self, low, high, shape, dtype):
+        self.low = low
+        self.high = high
+        self.shape = shape
+        self.dtype = dtype
+
+
+# mock gym.Wrapper
+class Wrapper(object):
+    def __init__(self, env):
+        self.env = env
+
+    def __getattr__(self, name):
+        if name.startswith('_'):
+            raise AttributeError(
+                "attempted to get missing private attribute '{}'".format(name))
+        return getattr(self.env, name)
+
+
+# mock gym.ObservationWrapper
+class ObservationWrapper(Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+    def reset(self, **kwargs):
+        observation = self.env.reset(**kwargs)
+        return self.observation(observation)
+
+    def step(self, action):
+        observation, reward, done, info = self.env.step(action)
+        return self.observation(observation), reward, done, info
+
+
+# mock gym.RewardWrapper
+class RewardWrapper(Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+
+# Atari Specific
+# mock env.action_space
+class ActionSpace(object):
+    def __init__(self, n, shape=None):
+        self.n = n
+        self.shape = shape
+
+
+# mock env.observation_space
+class ObservationSpace(object):
+    def __init__(self, dim, dtype):
+        self.shape = dim
+        self.dtype = dtype
+
+
+# mock env.spec
+class Spec(object):
+    def __init__(self, id='PongNoFrameskip-v4'):
+        self.id = id
+
+
+# mock gym.spaces
+class spaces(object):
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def Box(high, low, shape, dtype):
+        return ObservationSpace(shape, dtype)
+
+
+# mock CartPole-v0
+class CartPoleEnv(object):
+    def __init__(self):
+        self.observation_space = ObservationSpace((4, ), dtype='int8')
+        self.action_space = ActionSpace(2)
+
+    def step(self, action):
+        action = int(action)
+        obs = np.random.random(4) * 2 - 1
+        reward = np.random.choice([0.0, 1.0])
+        done = np.random.choice([True, False], p=[0.1, 0.9])
+        info = {}
+        return obs, reward, done, info
+
+    def reset(self):
+        obs = np.random.random(4) * 2 - 1
+        return obs
+
+    def seed(self, val):
+        pass
+
+    def close(self):
+        pass
+
+
+# mock PongNoFrameskip-v4
+class PongEnv(object):
+    def __init__(self):
+        class Lives(object):
+            def lives(self):
+                return np.random.randint(0, 5)
+
+        class Ale(object):
+            def __init__(self):
+                self.ale = Lives()
+                self.np_random = np.random
+
+            def get_action_meanings(self):
+                return ['NOOP'] * 6
+
+        self.observation_space = ObservationSpace((210, 160, 3), 'uint8')
+        self.action_space = ActionSpace(6)
+        self.unwrapped = Ale()
+        self.metadata = {'render.modes': []}
+        self.reward_range = [0, 1]
+        self.spec = Spec('PongNoFrameskip-v4')
+
+    def step(self, action):
+        action = int(action)
+        obs = np.random.randint(0, 255, (210, 160, 3), dtype=np.uint8)
+        reward = np.random.choice([0.0, 1.0])
+        done = np.random.choice([True, False], p=[0.1, 0.9])
+        info = {}
+        return obs, reward, done, info
+
+    def reset(self):
+        obs = np.random.randint(0, 255, (210, 160, 3), dtype=np.uint8)
+        return obs
+
+    def close(self):
+        pass
+
+    def seed(self, val):
+        pass
+
+
+# mock mujoco envs
+class HalfCheetahEnv(object):
+    def __init__(self):
+        self.observation_space = Box(
+            high=np.array([np.inf] * 17),
+            low=np.array([-np.inf] * 17),
+            shape=(17, ),
+            dtype=None)
+        self.action_space = Box(
+            high=np.array([1.0] * 6),
+            low=np.array([-1.0] * 6),
+            shape=(6, ),
+            dtype=None)
+        self._max_episode_steps = 1000
+        self._elapsed_steps = 0
+
+    def step(self, action):
+        obs = np.random.randn(17)
+        reward = np.random.choice([0.0, 1.0])
+        done = np.random.choice([True, False], p=[0.01, 0.99])
+        info = {}
+        return obs, reward, done, info
+
+    def reset(self):
+        obs = np.random.randn(17)
+        return obs
+
+    def seed(self, val):
+        pass
+
+    def close(self):
+        pass
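A short usage sketch, exercising only the API the mock defines above (make plus reset/step/close on the returned env):

env = make('CartPole-v0')              # prints the mock-env notice
obs = env.reset()                      # random 4-dim observation in [-1, 1)
obs, reward, done, info = env.step(0)  # random reward and done, empty info
assert obs.shape == (4, )
env.close()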
