6 changes: 4 additions & 2 deletions .teamcity/build.sh
@@ -49,12 +49,14 @@ function run_example_test {
python -m pip uninstall -r ./examples/DQN_variant/requirements.txt -y

python -m pip install -r ./examples/PPO/requirements_atari.txt
python examples/PPO/train.py --train_total_steps 5000 --env PongNoFrameskip-v4
python examples/PPO/atari/train.py --train_total_steps 5000 --env PongNoFrameskip-v4
python -m pip uninstall -r ./examples/PPO/requirements_atari.txt -y

xparl start --port 8010 --cpu_num 8
python -m pip install -r ./examples/PPO/requirements_mujoco.txt
python examples/PPO/train.py --train_total_steps 5000 --env HalfCheetah-v4 --continuous_action
python examples/PPO/mujoco/train.py --env 'HalfCheetah-v2' --train_total_episodes 100 --env_num 5
python -m pip uninstall -r ./examples/PPO/requirements_mujoco.txt -y
xparl stop

python -m pip install -r ./examples/SAC/requirements.txt
python examples/SAC/train.py --train_total_steps 5000 --env HalfCheetah-v4
35 changes: 18 additions & 17 deletions examples/PPO/README.md
@@ -24,7 +24,7 @@ PARL currently supports the open-source version of Mujoco provided by DeepMind,
+ python3.7+
+ [paddle>=2.3.1](https://github.com/PaddlePaddle/Paddle)
+ [parl>=2.1.1](https://github.com/PaddlePaddle/PARL)
+ gym>=0.26.0
+ gym==0.18.0
+ mujoco>=2.2.2

### Atari-Dependencies:
@@ -34,17 +34,8 @@ PARL currently supports the open-source version of Mujoco provided by DeepMind,
+ atari-py==0.2.6
+ opencv-python

### Training:

```
# To train an agent for discrete action game (Atari: PongNoFrameskip-v4 by default)
python train.py

# To train an agent for continuous action game (Mujoco)
python train.py --env 'HalfCheetah-v4' --continuous_action --train_total_steps 1000000
```

### Distributed Training
### Distributed Training on Mujoco
Accelerate the training process by setting `xparl_addr` and `env_num > 1` when the environment simulation is very slow.
First, we can start a local cluster with 8 CPUs:

@@ -56,14 +47,24 @@ Note that if you have started a master before, you don't have to run the above
command. For more information about the cluster, please refer to our
[documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html).

Then we can start the distributed training by running:
Then we can start distributed training for Mujoco games by running:

```
# To train an agent distributedly
cd mujoco

# for discrete action game (Atari games)
python train.py --env "PongNoFrameskip-v4" --env_num 8 --xparl_addr 'localhost:8010'
python train.py --env 'HalfCheetah-v2' --train_total_episodes 100000 --env_num 5
```

# for continuous action game (Mujoco games)
python train.py --env 'HalfCheetah-v4' --continuous_action --train_total_steps 1000000 --env_num 5 --xparl_addr 'localhost:8010'

### Training Atari
To train an agent for a discrete-action game (Atari: PongNoFrameskip-v4 by default):

```
cd atari

# Local training
python train.py
# Distributed training
xparl start --port 8010 --cpu_num 8
python train.py --env "PongNoFrameskip-v4" --env_num 8 --xparl_addr 'localhost:8010'
```
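For readers skimming the diff, here is a minimal sketch of the command-line interface the README commands above assume. The flag names are taken verbatim from the README; the defaults and help strings are illustrative assumptions, not copied from `train.py`.

```python
# Illustrative sketch only: flag names come from the README above,
# defaults and help text are assumptions rather than the real train.py values.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--env', default='PongNoFrameskip-v4', help='gym environment id')
parser.add_argument('--env_num', type=int, default=1, help='number of parallel environments')
parser.add_argument('--xparl_addr', default=None,
                    help="address of a running xparl cluster, e.g. 'localhost:8010'")
parser.add_argument('--train_total_steps', type=int, default=None,
                    help='total environment steps (Atari entry point)')
parser.add_argument('--train_total_episodes', type=int, default=None,
                    help='total training episodes (Mujoco entry point)')
args = parser.parse_args()
```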
15 changes: 6 additions & 9 deletions examples/PPO/agent.py → examples/PPO/atari/atari_agent.py
@@ -18,7 +18,7 @@
from parl.utils.scheduler import LinearDecayScheduler


class PPOAgent(parl.Agent):
class AtariAgent(parl.Agent):
""" Agent of PPO env

Args:
@@ -27,12 +27,11 @@ class PPOAgent(parl.Agent):
"""

def __init__(self, algorithm, config):
super(PPOAgent, self).__init__(algorithm)
super(AtariAgent, self).__init__(algorithm)

self.config = config
if self.config['lr_decay']:
self.lr_scheduler = LinearDecayScheduler(
self.config['initial_lr'], self.config['num_updates'])
self.lr_scheduler = LinearDecayScheduler(self.config['initial_lr'], self.config['num_updates'])

def predict(self, obs):
""" Predict action from current policy given observation
@@ -85,8 +84,7 @@ def learn(self, rollout):
else:
lr = None

minibatch_size = int(
self.config['batch_size'] // self.config['num_minibatches'])
minibatch_size = int(self.config['batch_size'] // self.config['num_minibatches'])

indexes = np.arange(self.config['batch_size'])
for epoch in range(self.config['update_epochs']):
@@ -105,9 +103,8 @@ def learn(self, rollout):
batch_return = paddle.to_tensor(batch_return)
batch_value = paddle.to_tensor(batch_value)

value_loss, action_loss, entropy_loss = self.alg.learn(
batch_obs, batch_action, batch_value, batch_return,
batch_logprob, batch_adv, lr)
value_loss, action_loss, entropy_loss = self.alg.learn(batch_obs, batch_action, batch_value,
batch_return, batch_logprob, batch_adv, lr)

value_loss_epoch += value_loss
action_loss_epoch += action_loss
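The `learn()` hunks above only reflow the code. For orientation, this is a condensed sketch of the minibatch update loop they belong to, with the `sample_batch` ordering and `alg.learn` signature taken from this diff; loss accumulation and learning-rate decay are omitted.

```python
import numpy as np
import paddle


def run_ppo_update(agent, rollout, config):
    """Condensed view of AtariAgent.learn: shuffled minibatch SGD over one rollout."""
    minibatch_size = config['batch_size'] // config['num_minibatches']
    indexes = np.arange(config['batch_size'])
    for _ in range(config['update_epochs']):
        np.random.shuffle(indexes)                         # new permutation each epoch
        for start in range(0, config['batch_size'], minibatch_size):
            idx = indexes[start:start + minibatch_size]
            obs, action, logprob, adv, ret, value = rollout.sample_batch(idx)
            # the PPO clipped-surrogate and value losses live inside the algorithm object
            agent.alg.learn(paddle.to_tensor(obs), paddle.to_tensor(action),
                            paddle.to_tensor(value), paddle.to_tensor(ret),
                            paddle.to_tensor(logprob), paddle.to_tensor(adv), None)
```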
File renamed without changes.
File renamed without changes.
84 changes: 15 additions & 69 deletions examples/PPO/env_utils.py → examples/PPO/atari/env_utils.py
@@ -16,13 +16,11 @@
import gym
import numpy as np
from parl.utils import logger
from parl.env.atari_wrappers import wrap_deepmind

TEST_EPISODE = 3
# wrapper parameters for atari env
ENV_DIM = 84
OBS_FORMAT = 'NCHW'
# wrapper parameters for mujoco env
GAMMA = 0.99


class ParallelEnv(object):
@@ -39,14 +37,9 @@ def __init__(self, config=None):
base_env = LocalEnv

if config['seed']:
self.env_list = [
base_env(config['env'], config['seed'] + i)
for i in range(self.env_num)
]
self.env_list = [base_env(config['env'], config['seed'] + i) for i in range(self.env_num)]
else:
self.env_list = [
base_env(config['env']) for _ in range(self.env_num)
]
self.env_list = [base_env(config['env']) for _ in range(self.env_num)]
if hasattr(self.env_list[0], '_max_episode_steps'):
self._max_episode_steps = self.env_list[0]._max_episode_steps
else:
@@ -68,10 +61,7 @@ def reset(self):
def step(self, action_list):
next_obs_list, reward_list, done_list, info_list = [], [], [], []
if self.use_xparl:
return_list = [
self.env_list[i].step(action_list[i])
for i in range(self.env_num)
]
return_list = [self.env_list[i].step(action_list[i]) for i in range(self.env_num)]
return_list = [return_.get() for return_ in return_list]
return_list = np.array(return_list, dtype=object)

@@ -89,8 +79,7 @@ def step(self, action_list):
done = done_[i]
info = info_[i]
else:
next_obs, reward, done, info = self.env_list[i].step(
action_list[i])
next_obs, reward, done, info = self.env_list[i].step(action_list[i])

self.episode_steps_list[i] += 1
self.episode_reward_list[i] += reward
@@ -104,49 +93,26 @@ def step(self, action_list):
next_obs = self.env_list[i].reset()
self.episode_steps_list[i] = 0
self.episode_reward_list[i] = 0
if self.env_list[i].continuous_action:
# get running mean and variance of obs
self.eval_ob_rms = self.env_list[i].env.get_ob_rms()

next_obs_list.append(next_obs)
reward_list.append(reward)
done_list.append(done)
info_list.append(info)
return np.array(next_obs_list), np.array(reward_list), np.array(
done_list), np.array(info_list)
return np.array(next_obs_list), np.array(reward_list), np.array(done_list), np.array(info_list)


class LocalEnv(object):
def __init__(self, env_name, env_seed=None, test=False, ob_rms=None):
env = gym.make(env_name)

# is instance of gym.spaces.Box
if hasattr(env.action_space, 'high'):
from parl.env.mujoco_wrappers import wrap_rms
self._max_episode_steps = env._max_episode_steps
self.continuous_action = True
if test:
self.env = wrap_rms(env, GAMMA, test=True, ob_rms=ob_rms)
else:
self.env = wrap_rms(env, gamma=GAMMA)
# is instance of gym.spaces.Discrete
elif hasattr(env.action_space, 'n'):
from parl.env.atari_wrappers import wrap_deepmind
self.continuous_action = False
if hasattr(env.action_space, 'n'):
if test:
self.env = wrap_deepmind(
env,
dim=ENV_DIM,
obs_format=OBS_FORMAT,
test=True,
test_episodes=1)
self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT, test=True, test_episodes=1)
else:
self.env = wrap_deepmind(
env, dim=ENV_DIM, obs_format=OBS_FORMAT)
self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT)
else:
raise AssertionError(
'act_space must be instance of gym.spaces.Box or gym.spaces.Discrete'
)
raise AssertionError('act_space must be instance of gym.spaces.Discrete')

self.obs_space = self.env.observation_space
self.act_space = self.env.action_space
@@ -166,31 +132,13 @@ class RemoteEnv(object):
def __init__(self, env_name, env_seed=None, test=False, ob_rms=None):
env = gym.make(env_name)

if hasattr(env.action_space, 'high'):
from parl.env.mujoco_wrappers import wrap_rms
self._max_episode_steps = env._max_episode_steps
self.continuous_action = True
if test:
self.env = wrap_rms(env, GAMMA, test=True, ob_rms=ob_rms)
else:
self.env = wrap_rms(env, gamma=GAMMA)
elif hasattr(env.action_space, 'n'):
from parl.env.atari_wrappers import wrap_deepmind
self.continuous_action = False
if hasattr(env.action_space, 'n'):
if test:
self.env = wrap_deepmind(
env,
dim=ENV_DIM,
obs_format=OBS_FORMAT,
test=True,
test_episodes=1)
self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT, test=True, test_episodes=1)
else:
self.env = wrap_deepmind(
env, dim=ENV_DIM, obs_format=OBS_FORMAT)
self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT)
else:
raise AssertionError(
'act_space must be instance of gym.spaces.Box or gym.spaces.Discrete'
)
raise AssertionError('act_space must be instance of gym.spaces.Discrete')
if env_seed:
self.env.seed(env_seed)

@@ -201,6 +149,4 @@
return self.env.step(action)

def render(self):
return logger.warning(
'Can not render in remote environment, render() have been skipped.'
)
return logger.warning('Can not render in remote environment, render() have been skipped.')
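For context on the `.get()` calls in `ParallelEnv.step` above: `RemoteEnv` is meant to run as an xparl actor, so its method calls return futures. Below is a minimal, self-contained sketch of that pattern; `parl.connect` and `parl.remote_class` are PARL's public API, but the exact decorator arguments used by `RemoteEnv` are not shown in this diff, so treat the `wait=False` detail as an assumption.

```python
import parl


# Assumption: the real RemoteEnv is decorated so that method calls return
# future objects (e.g. @parl.remote_class(wait=False) in recent PARL versions),
# which is why ParallelEnv.step() resolves each result with .get().
@parl.remote_class(wait=False)
class Worker(object):
    def step(self, x):
        return x + 1


# Requires a running cluster, e.g.: xparl start --port 8010 --cpu_num 8
parl.connect('localhost:8010')

workers = [Worker() for _ in range(4)]
futures = [w.step(i) for i, w in enumerate(workers)]  # dispatched in parallel, non-blocking
print([f.get() for f in futures])                     # .get() blocks until each result arrives
```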
15 changes: 5 additions & 10 deletions examples/PPO/storage.py → examples/PPO/atari/storage.py
@@ -17,10 +17,8 @@

class RolloutStorage():
def __init__(self, step_nums, env_num, obs_space, act_space):
self.obs = np.zeros(
(step_nums, env_num) + obs_space.shape, dtype='float32')
self.actions = np.zeros(
(step_nums, env_num) + act_space.shape, dtype='float32')
self.obs = np.zeros((step_nums, env_num) + obs_space.shape, dtype='float32')
self.actions = np.zeros((step_nums, env_num) + act_space.shape, dtype='float32')
self.logprobs = np.zeros((step_nums, env_num), dtype='float32')
self.rewards = np.zeros((step_nums, env_num), dtype='float32')
self.dones = np.zeros((step_nums, env_num), dtype='float32')
@@ -54,10 +52,8 @@ def compute_returns(self, value, done, gamma=0.99, gae_lambda=0.95):
else:
nextnonterminal = 1.0 - self.dones[t + 1]
nextvalues = self.values[t + 1]
delta = self.rewards[
t] + gamma * nextvalues * nextnonterminal - self.values[t]
advantages[
t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
delta = self.rewards[t] + gamma * nextvalues * nextnonterminal - self.values[t]
advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
returns = advantages + self.values
self.returns = returns
self.advantages = advantages
@@ -72,5 +68,4 @@ def sample_batch(self, idx):
b_returns = self.returns.reshape(-1)
b_values = self.values.reshape(-1)

return b_obs[idx], b_actions[idx], b_logprobs[idx], b_advantages[
idx], b_returns[idx], b_values[idx]
return b_obs[idx], b_actions[idx], b_logprobs[idx], b_advantages[idx], b_returns[idx], b_values[idx]
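The `compute_returns` change above is formatting only. For readers new to the estimator, this is the GAE(λ) recursion that loop implements, written out as a standalone function; it is a sketch mirroring the code above, not code from the PR, and it assumes `rewards`, `values`, and `dones` are numpy arrays of shape `(step_nums,)`.

```python
import numpy as np


def compute_gae(rewards, values, dones, last_value, last_done, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation, mirroring RolloutStorage.compute_returns."""
    step_nums = len(rewards)
    advantages = np.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(step_nums)):
        if t == step_nums - 1:                      # bootstrap from the value of the final state
            nextnonterminal = 1.0 - last_done
            nextvalues = last_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        # one-step TD error, then the discounted, lambda-weighted sum of future TD errors
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    returns = advantages + values
    return advantages, returns
```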