32 changes: 32 additions & 0 deletions docs/guide/examples.rst
@@ -114,6 +114,38 @@ Train a PPO agent with a recurrent policy on the CartPole environment.
episode_starts = dones
vec_env.render("human")


.. note::

You can also use a recurrent policy based on vanilla RNNs rather than LSTMs.


.. code-block:: python

import numpy as np

from sb3_contrib import RecurrentPPO

policy_kwargs = {"recurrent_layer_type": "rnn"} # The default is "lstm"

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(5000)

vec_env = model.get_env()
obs = vec_env.reset()
# Hidden state of the RNN
state = None
num_envs = 1
# Episode start signals are used to reset the RNN state
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
action, state = model.predict(obs, state=state, episode_start=episode_starts, deterministic=True)
# Note: vectorized environment resets automatically
obs, rewards, dones, info = vec_env.step(action)
episode_starts = dones
vec_env.render("human")


CrossQ
------

7 changes: 7 additions & 0 deletions docs/misc/changelog.rst
@@ -3,6 +3,13 @@
Changelog
==========

TBD (2025-06-11)
--------------------------

New Features:
^^^^^^^^^^^^^
- Added support for vanilla RNNs in RecurrentPPO via the ``recurrent_layer_type`` argument in ``policy_kwargs`` (@gcroci2)

Release 2.6.1a1 (WIP)
--------------------------

30 changes: 30 additions & 0 deletions docs/modules/ppo_recurrent.rst
@@ -86,6 +86,36 @@ Example
vec_env.render("human")


.. note::

You can also use a recurrent policy based on vanilla RNNs rather than LSTMs.


.. code-block:: python

import numpy as np

from sb3_contrib import RecurrentPPO

policy_kwargs = {"recurrent_layer_type": "rnn"} # The default is "lstm"

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(5000)

vec_env = model.get_env()
obs = vec_env.reset()
# Hidden state of the RNN
state = None
num_envs = 1
# Episode start signals are used to reset the RNN state
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
action, state = model.predict(obs, state=state, episode_start=episode_starts, deterministic=True)
# Note: vectorized environment resets automatically
obs, rewards, dones, info = vec_env.step(action)
episode_starts = dones
vec_env.render("human")


Results
-------
124 changes: 86 additions & 38 deletions sb3_contrib/common/recurrent/buffers.py
@@ -109,6 +109,7 @@ class RecurrentRolloutBuffer(RolloutBuffer):
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
:param recurrent_layer_type: Type of recurrent layer ("lstm" or "rnn")
"""

def __init__(
@@ -121,9 +122,11 @@ def __init__(
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
recurrent_layer_type: str = "lstm",
):
self.hidden_state_shape = hidden_state_shape
self.seq_start_indices, self.seq_end_indices = None, None
self.recurrent_layer_type = recurrent_layer_type.lower()
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)

def reset(self):
@@ -137,10 +140,25 @@ def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
"""
:param lstm_states: Recurrent states (LSTM hidden and cell states, or the RNN hidden state)
"""
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())
# Actor states
if self.recurrent_layer_type == "lstm":
# LSTM case: (hidden, cell) tuple
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
else:
# RNN case: single hidden state tensor
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi.cpu().numpy())
self.cell_states_pi[self.pos] = np.zeros_like(self.hidden_states_pi[self.pos])

# Critic states
if self.recurrent_layer_type == "lstm":
# LSTM case: (hidden, cell) tuple
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())
else:
# RNN case: single hidden state tensor
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf.cpu().numpy())
self.cell_states_vf[self.pos] = np.zeros_like(self.hidden_states_vf[self.pos])

super().add(*args, **kwargs)
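
The two branches above mirror the difference between the PyTorch modules: ``nn.LSTM`` returns a ``(hidden, cell)`` tuple as its final state, while ``nn.RNN`` returns a single hidden tensor, so the buffer zero-fills the unused cell-state slot in the RNN case to keep a uniform storage layout. A small standalone sketch of that API difference (illustrative only, not part of the patch):

import torch
import torch.nn as nn

x = torch.randn(5, 1, 8)  # (seq_len, batch, input_size)

lstm = nn.LSTM(input_size=8, hidden_size=16)
_, (h_lstm, c_lstm) = lstm(x)  # LSTM: final state is a (hidden, cell) tuple
print(h_lstm.shape, c_lstm.shape)  # torch.Size([1, 1, 16]) torch.Size([1, 1, 16])

rnn = nn.RNN(input_size=8, hidden_size=16)
_, h_rnn = rnn(x)  # vanilla RNN: final state is a single hidden tensor
print(h_rnn.shape)  # torch.Size([1, 1, 16])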

@@ -211,22 +229,30 @@ def _get_samples(
n_seq = len(self.seq_start_indices)
max_length = self.pad(self.actions[batch_inds]).shape[1]
padded_batch_size = n_seq * max_length
# We retrieve the lstm hidden states that will allow
# to properly initialize the LSTM at the beginning of each sequence
lstm_states_pi = (
# 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim)
# 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim)
# 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_vf = (
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())

# Use the stored recurrent layer type to determine state structure
if self.recurrent_layer_type == "lstm":
# LSTM case: return tuples
lstm_states_pi = (
# 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim)
# 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim)
# 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_vf = (
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())
else:
# RNN case: return single tensors (only hidden states, no cell states)
lstm_states_pi = self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1)
lstm_states_vf = self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1)
lstm_states_pi = self.to_torch(lstm_states_pi).contiguous()
lstm_states_vf = self.to_torch(lstm_states_vf).contiguous()

return RecurrentRolloutBufferSamples(
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
@@ -256,6 +282,7 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer):
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
:param recurrent_layer_type: Type of recurrent layer ("lstm" or "rnn")
"""

def __init__(
@@ -268,9 +295,11 @@ def __init__(
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
recurrent_layer_type: str = "lstm",
):
self.hidden_state_shape = hidden_state_shape
self.seq_start_indices, self.seq_end_indices = None, None
self.recurrent_layer_type = recurrent_layer_type.lower()
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)

def reset(self):
@@ -284,10 +313,21 @@ def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
"""
:param lstm_states: Recurrent states (LSTM hidden and cell states, or the RNN hidden state)
"""
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())
# Actor
if self.recurrent_layer_type == "lstm":
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
else:
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi.cpu().numpy())
self.cell_states_pi[self.pos] = np.zeros_like(self.hidden_states_pi[self.pos])

# Critic
if self.recurrent_layer_type == "lstm":
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())
else:
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf.cpu().numpy())
self.cell_states_vf[self.pos] = np.zeros_like(self.hidden_states_vf[self.pos])

super().add(*args, **kwargs)

@@ -354,20 +394,28 @@ def _get_samples(
n_seq = len(self.seq_start_indices)
max_length = self.pad(self.actions[batch_inds]).shape[1]
padded_batch_size = n_seq * max_length
# We retrieve the lstm hidden states that will allow
# to properly initialize the LSTM at the beginning of each sequence
lstm_states_pi = (
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_vf = (
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())

# Use the stored recurrent layer type to determine state structure
if self.recurrent_layer_type == "lstm":
# LSTM case: return tuples
lstm_states_pi = (
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_vf = (
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())
else:
# RNN case: return single tensors (only hidden states, no cell states)
lstm_states_pi = self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1)
lstm_states_vf = self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1)
lstm_states_pi = self.to_torch(lstm_states_pi).contiguous()
lstm_states_vf = self.to_torch(lstm_states_vf).contiguous()

observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}