
Commit 75b2de1

araffin and Walon1998 authored
Recurrent PPO (#53)
* Running (not working yet) version of recurrent PPO
* Fixes for multi envs
* Save WIP, rework the sampling
* Add Box support
* Fix sample order
* Begin cleanup, code is broken (again)
* First working version (no shared lstm)
* Start cleanup
* Try rnn with value function
* Re-enable batch size
* Deactivate vf rnn
* Allow any batch size
* Add support for evaluation
* Add CNN support
* Fix start of sequence
* Allow shared LSTM
* Rename mask to episode_start
* Fix type hint
* Enable LSTM for critic
* Clean code
* Fix for CNN LSTM
* Fix sampling with n_layers > 1
* Add std logger
* Update wording
* Rename and add dict obs support
* Fixes for dict obs support
* Do not run slow tests
* Fix doc
* Update recurrent PPO example
* Update README
* Use Pendulum-v1 for tests
* Fix image env
* Speedup LSTM forward pass (#63)
* added more efficient lstm implementation
* Rename and add comment
Co-authored-by: Antonin Raffin <[email protected]>
* Fixes
* Remove OpenAI sampling and improve coverage
* Sync with SB3 PPO
* Pass state shape and allow lstm kwargs
* Update tests
* Add masking for padded sequences
* Update default in perf test
* Remove TODO, mask is now working
* Add helper to remove duplicated code, remove hack for padding
* Enable LSTM critic and raise threshold for cartpole with no vel
* Fix tests
* Update doc and tests
* Doc fix
* Fix for new Sphinx version
* Fix doc note
* Switch to batch first, no more additional swap
* Add comments and mask entropy loss

Co-authored-by: Neville Walo <[email protected]>
1 parent cd592a1 commit 75b2de1
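
A note on two items in the message above, "Add masking for padded sequences" and "Add comments and mask entropy loss": the snippet below is a minimal illustrative sketch of that general idea, not code from this commit. When variable-length episode sequences are zero-padded so they can be batched, the per-timestep losses are averaged only over the real steps, so the padding does not bias the gradients.

import torch

def masked_mean(per_step_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # per_step_loss and mask have shape (batch, seq_len); mask is 1.0 for real
    # timesteps and 0.0 for the padding added when batching variable-length sequences.
    return (per_step_loss * mask).sum() / mask.sum().clamp(min=1)

# Toy example: a padded batch of two sequences with lengths 3 and 1.
entropy = torch.rand(2, 3)
mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])
entropy_loss = -masked_mean(entropy, mask)  # padded steps contribute nothing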

File tree

23 files changed: +1988 −28 lines changed

README.md

Lines changed: 3 additions & 2 deletions
@@ -25,11 +25,12 @@ We hope this allows us to provide reliable implementations following stable-base
  See documentation for the full list of included features.

  **RL Algorithms**:
- - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
+ - [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
  - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
  - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
+ - [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
+ - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
  - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- - [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)

  **Gym Wrappers**:
  - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)

docs/guide/algos.rst

Lines changed: 4 additions & 1 deletion
@@ -9,14 +9,17 @@ along with some useful characteristics: support for discrete/continuous actions,
  Name         ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
  ============ =========== ============ ================= =============== ================
  ARS          ✔️       ❌           ❌                ❌               ✔️
+ MaskablePPO  ❌       ✔️           ✔️                ✔️               ✔️
  QR-DQN       ❌       ✔️           ❌                ❌               ✔️
+ RecurrentPPO ✔️       ✔️           ✔️                ✔️               ✔️
  TQC          ✔️       ❌           ❌                ❌               ✔️
  TRPO         ✔️       ✔️           ✔️                ✔️               ✔️
  ============ =========== ============ ================= =============== ================


  .. note::
-     Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
+     ``Tuple`` observation spaces are not supported by any environment,
+     however, single-level ``Dict`` spaces are


  Actions ``gym.spaces``:
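
To illustrate the single-level ``Dict`` observation support noted above, here is a minimal sketch. It assumes the ``SimpleMultiObsEnv`` toy environment shipped with stable-baselines3 (any env with a single-level Dict observation space would do); this example is not part of the commit.

from stable_baselines3.common.envs import SimpleMultiObsEnv

from sb3_contrib import RecurrentPPO

# SimpleMultiObsEnv exposes a single-level Dict observation space
env = SimpleMultiObsEnv(random_start=False)
# MultiInputLstmPolicy is the Dict-observation counterpart of MlpLstmPolicy
model = RecurrentPPO("MultiInputLstmPolicy", env, verbose=1)
model.learn(total_timesteps=500)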

docs/guide/examples.rst

Lines changed: 35 additions & 0 deletions
@@ -71,3 +71,38 @@ Train an agent using Augmented Random Search (ARS) agent on the Pendulum environ
      model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
      model.learn(total_timesteps=10000, log_interval=4)
      model.save("ars_pendulum")
+
+ RecurrentPPO
+ ------------
+
+ Train a PPO agent with a recurrent policy on the CartPole environment.
+
+
+ .. note::
+
+     It is particularly important to pass the ``lstm_states``
+     and ``episode_start`` argument to the ``predict()`` method,
+     so the cell and hidden states of the LSTM are correctly updated.
+
+
+ .. code-block:: python
+
+     import numpy as np
+
+     from sb3_contrib import RecurrentPPO
+
+     model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
+     model.learn(5000)
+
+     env = model.get_env()
+     obs = env.reset()
+     # cell and hidden state of the LSTM
+     lstm_states = None
+     num_envs = 1
+     # Episode start signals are used to reset the lstm states
+     episode_starts = np.ones((num_envs,), dtype=bool)
+     while True:
+         action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
+         obs, rewards, dones, info = env.step(action)
+         episode_starts = dones
+         env.render()

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d

      modules/ars
      modules/ppo_mask
+     modules/ppo_recurrent
      modules/qrdqn
      modules/tqc
      modules/trpo

docs/misc/changelog.rst

Lines changed: 9 additions & 2 deletions
@@ -3,9 +3,11 @@
  Changelog
  ==========

- Release 1.5.1a7 (WIP)
+ Release 1.5.1a8 (WIP)
  -------------------------------

+ **Add RecurrentPPO (aka PPO LSTM)**
+
  Breaking Changes:
  ^^^^^^^^^^^^^^^^^
  - Upgraded to Stable-Baselines3 >= 1.5.1a7

@@ -17,6 +19,7 @@ Breaking Changes:

  New Features:
  ^^^^^^^^^^^^^
+ - Added ``RecurrentPPO`` (aka PPO LSTM)

  Bug Fixes:
  ^^^^^^^^^^

@@ -34,7 +37,8 @@ Breaking Changes:

  New Features:
  ^^^^^^^^^^^^^
- - Allow PPO to turn of advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) @vwxyzjn
+ - Allow PPO to turn of advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) (@vwxyzjn)
+

  Bug Fixes:
  ^^^^^^^^^^

@@ -46,6 +50,9 @@ Deprecations:
  Others:
  ^^^^^^^

+ Documentation:
+ ^^^^^^^^^^^^^^
+
  Release 1.4.0 (2022-01-19)
  -------------------------------
  **Add Trust Region Policy Optimization (TRPO) and Augmented Random Search (ARS) algorithms**

docs/modules/ppo_mask.rst

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
  Maskable PPO
  ============

- Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization(PPO)
+ Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization (PPO)
  algorithm. Other than adding support for action masking, the behavior is the same as in SB3's core PPO algorithm.

docs/modules/ppo_recurrent.rst

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+ .. _ppo_lstm:
+
+ .. automodule:: sb3_contrib.ppo_recurrent
+
+ Recurrent PPO
+ =============
+
+ Implementation of recurrent policies for the Proximal Policy Optimization (PPO)
+ algorithm. Other than adding support for recurrent policies (LSTM here), the behavior is the same as in SB3's core PPO algorithm.
+
+
+ .. rubric:: Available Policies
+
+ .. autosummary::
+     :nosignatures:
+
+     MlpLstmPolicy
+     CnnLstmPolicy
+     MultiInputLstmPolicy
+
+
+ Notes
+ -----
+
+ - Blog post: https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/
+
+
+ Can I use?
+ ----------
+
+ - Recurrent policies: ✔️
+ - Multi processing: ✔️
+ - Gym spaces:
+
+
+ ============= ====== ===========
+ Space         Action Observation
+ ============= ====== ===========
+ Discrete      ✔️      ✔️
+ Box           ✔️      ✔️
+ MultiDiscrete ✔️      ✔️
+ MultiBinary   ✔️      ✔️
+ Dict          ❌      ✔️
+ ============= ====== ===========
+
+
+ Example
+ -------
+
+ .. note::
+
+     It is particularly important to pass the ``lstm_states``
+     and ``episode_start`` argument to the ``predict()`` method,
+     so the cell and hidden states of the LSTM are correctly updated.
+
+
+ .. code-block:: python
+
+     import numpy as np
+
+     from sb3_contrib import RecurrentPPO
+     from stable_baselines3.common.evaluation import evaluate_policy
+
+     model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
+     model.learn(5000)
+
+     env = model.get_env()
+     mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
+     print(mean_reward)
+
+     model.save("ppo_recurrent")
+     del model # remove to demonstrate saving and loading
+
+     model = RecurrentPPO.load("ppo_recurrent")
+
+     obs = env.reset()
+     # cell and hidden state of the LSTM
+     lstm_states = None
+     num_envs = 1
+     # Episode start signals are used to reset the lstm states
+     episode_starts = np.ones((num_envs,), dtype=bool)
+     while True:
+         action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
+         obs, rewards, dones, info = env.step(action)
+         episode_starts = dones
+         env.render()
+
+
+
+ Results
+ -------
+
+ Report on environments with masked velocity (with and without framestack) can be found here: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4
+
+ ``RecurrentPPO`` was evaluated against PPO on:
+
+ - PendulumNoVel-v1
+ - LunarLanderNoVel-v2
+ - CartPoleNoVel-v1
+ - MountainCarContinuousNoVel-v0
+ - CarRacing-v0
+
+ How to replicate the results?
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ Clone the repo for the experiment:
+
+ .. code-block:: bash
+
+     git clone https://github.com/DLR-RM/rl-baselines3-zoo
+     cd rl-baselines3-zoo
+     git checkout feat/recurrent-ppo
+
+
+ Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
+
+ .. code-block:: bash
+
+     python train.py --algo ppo_lstm --env $ENV_ID --eval-episodes 10 --eval-freq 10000
+
+
+ Parameters
+ ----------
+
+ .. autoclass:: RecurrentPPO
+     :members:
+     :inherited-members:
+
+
+ RecurrentPPO Policies
+ ---------------------
+
+ .. autoclass:: MlpLstmPolicy
+     :members:
+     :inherited-members:
+
+ .. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticPolicy
+     :members:
+     :noindex:
+
+ .. autoclass:: CnnLstmPolicy
+     :members:
+
+ .. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticCnnPolicy
+     :members:
+     :noindex:
+
+ .. autoclass:: MultiInputLstmPolicy
+     :members:
+
+ .. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentMultiInputActorCriticPolicy
+     :members:
+     :noindex:
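
The commit message mentions "Pass state shape and allow lstm kwargs", "Allow shared LSTM" and "Enable LSTM for critic", so the recurrent architecture is configurable through ``policy_kwargs``. The sketch below is hypothetical: the keyword names ``lstm_hidden_size``, ``n_lstm_layers`` and ``enable_critic_lstm`` are assumptions inferred from the commit message; check the policy signatures above for the exact names.

from sb3_contrib import RecurrentPPO

# The keyword names below are assumed (see note above); they are forwarded
# to the recurrent policy via policy_kwargs.
model = RecurrentPPO(
    "MlpLstmPolicy",
    "CartPole-v1",
    policy_kwargs=dict(
        lstm_hidden_size=64,      # size of the LSTM hidden/cell state (assumed name)
        n_lstm_layers=1,          # number of stacked LSTM layers (assumed name)
        enable_critic_lstm=True,  # give the value function its own LSTM (assumed name)
    ),
    verbose=1,
)
model.learn(total_timesteps=5_000)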

sb3_contrib/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@

  from sb3_contrib.ars import ARS
  from sb3_contrib.ppo_mask import MaskablePPO
+ from sb3_contrib.ppo_recurrent import RecurrentPPO
  from sb3_contrib.qrdqn import QRDQN
  from sb3_contrib.tqc import TQC
  from sb3_contrib.trpo import TRPO

sb3_contrib/common/maskable/policies.py

Lines changed: 5 additions & 5 deletions
@@ -215,12 +215,12 @@ def predict(
      action_masks: Optional[np.ndarray] = None,
  ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
      """
-     Get the policy action and state from an observation (and optional state).
+     Get the policy action from an observation (and optional hidden state).
      Includes sugar-coating to handle different observations (e.g. normalizing images).

      :param observation: the input observation
      :param state: The last states (can be None, used in recurrent policies)
-     :param mask: The last masks (can be None, used in recurrent policies)
+     :param episode_start: The last masks (can be None, used in recurrent policies)
      :param deterministic: Whether or not to return deterministic actions.
      :param action_masks: Action masks to apply to the action distribution
      :return: the model's action and the next state

@@ -229,8 +229,8 @@ def predict(
      # TODO (GH/1): add support for RNN policies
      # if state is None:
      #     state = self.initial_state
-     # if mask is None:
-     #     mask = [False for _ in range(self.n_envs)]
+     # if episode_start is None:
+     #     episode_start = [False for _ in range(self.n_envs)]

      # Switch to eval mode (this affects batch norm / dropout)
      self.set_training_mode(False)

@@ -256,7 +256,7 @@ def predict(
      raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
      actions = actions[0]

-     return actions, state
+     return actions, None

  def evaluate_actions(
      self,
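
For reference, a minimal caller-side sketch of the contract changed above: ``predict()`` still returns an ``(action, state)`` tuple, but for this non-recurrent maskable policy the returned state is now always ``None`` (recurrent policies return their updated hidden states instead). ``InvalidActionEnvDiscrete`` is assumed to be the toy action-masking env shipped with sb3_contrib, and the constructor arguments below are illustrative.

import numpy as np

from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete

env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=5)
model = MaskablePPO("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=1_000)

obs = env.reset()
# episode_start replaces the old `mask` argument and is only used by recurrent policies
episode_start = np.ones((1,), dtype=bool)
action, state = model.predict(obs, episode_start=episode_start, deterministic=True)
assert state is None  # non-recurrent policy: no hidden state is carried over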

sb3_contrib/common/recurrent/__init__.py

Whitespace-only changes.
