diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst
index 5e152265..9603b299 100644
--- a/docs/guide/examples.rst
+++ b/docs/guide/examples.rst
@@ -114,6 +114,38 @@ Train a PPO agent with a recurrent policy on the CartPole environment.
       episode_starts = dones
       vec_env.render("human")
+
+.. note::
+
+  You can also use a recurrent policy based on vanilla RNNs rather than LSTMs.
+
+
+.. code-block:: python
+
+  import numpy as np
+
+  from sb3_contrib import RecurrentPPO
+
+  policy_kwargs = {"recurrent_layer_type": "rnn"}  # The default is "lstm"
+
+  model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
+  model.learn(5000)
+
+  vec_env = model.get_env()
+  obs = vec_env.reset()
+  # Hidden state of the RNN
+  state = None
+  num_envs = 1
+  # Episode start signals are used to reset the RNN state
+  episode_starts = np.ones((num_envs,), dtype=bool)
+  while True:
+      action, state = model.predict(obs, state=state, episode_start=episode_starts, deterministic=True)
+      # Note: vectorized environment resets automatically
+      obs, rewards, dones, info = vec_env.step(action)
+      episode_starts = dones
+      vec_env.render("human")
+
+
 CrossQ
 ------
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index fa37555e..3705a2d1 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,6 +3,13 @@
 Changelog
 ==========
+TBD (2025-06-11)
+--------------------------
+
+New Features:
+^^^^^^^^^^^^^
+- Added support for vanilla RNNs in ``RecurrentPPO`` via the ``recurrent_layer_type`` argument in ``policy_kwargs`` (@gcroci2)
+
 Release 2.6.1a1 (WIP)
 --------------------------
diff --git a/docs/modules/ppo_recurrent.rst b/docs/modules/ppo_recurrent.rst
index 31e3d340..d6094eb0 100644
--- a/docs/modules/ppo_recurrent.rst
+++ b/docs/modules/ppo_recurrent.rst
@@ -86,6 +86,36 @@ Example
       vec_env.render("human")
+.. note::
+
+  You can also use a recurrent policy based on vanilla RNNs rather than LSTMs.
+
+
+.. code-block:: python
+
+  import numpy as np
+
+  from sb3_contrib import RecurrentPPO
+
+  policy_kwargs = {"recurrent_layer_type": "rnn"}  # The default is "lstm"
+
+  model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
+  model.learn(5000)
+
+  vec_env = model.get_env()
+  obs = vec_env.reset()
+  # Hidden state of the RNN
+  state = None
+  num_envs = 1
+  # Episode start signals are used to reset the RNN state
+  episode_starts = np.ones((num_envs,), dtype=bool)
+  while True:
+      action, state = model.predict(obs, state=state, episode_start=episode_starts, deterministic=True)
+      # Note: vectorized environment resets automatically
+      obs, rewards, dones, info = vec_env.step(action)
+      episode_starts = dones
+      vec_env.render("human")
+
 Results
 -------
diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py
index 5386db11..3fcf338d 100644
--- a/sb3_contrib/common/recurrent/buffers.py
+++ b/sb3_contrib/common/recurrent/buffers.py
@@ -109,6 +109,7 @@ class RecurrentRolloutBuffer(RolloutBuffer):
        Equivalent to classic advantage when set to 1.
:param gamma: Discount factor :param n_envs: Number of parallel environments + :param recurrent_layer_type: Type of recurrent layer ("lstm" or "rnn") """ def __init__( @@ -121,9 +122,11 @@ def __init__( gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, + recurrent_layer_type: str = "lstm", ): self.hidden_state_shape = hidden_state_shape self.seq_start_indices, self.seq_end_indices = None, None + self.recurrent_layer_type = recurrent_layer_type.lower() super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) def reset(self): @@ -137,10 +140,25 @@ def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) - self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) - self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) - self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + # Actor states + if self.recurrent_layer_type == "lstm": + # LSTM case: (hidden, cell) tuple + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + else: + # RNN case: single hidden state tensor + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi.cpu().numpy()) + self.cell_states_pi[self.pos] = np.zeros_like(self.hidden_states_pi[self.pos]) + + # Critic states + if self.recurrent_layer_type == "lstm": + # LSTM case: (hidden, cell) tuple + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + else: + # RNN case: single hidden state tensor + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf.cpu().numpy()) + self.cell_states_vf[self.pos] = np.zeros_like(self.hidden_states_vf[self.pos]) super().add(*args, **kwargs) @@ -211,22 +229,30 @@ def _get_samples( n_seq = len(self.seq_start_indices) max_length = self.pad(self.actions[batch_inds]).shape[1] padded_batch_size = n_seq * max_length - # We retrieve the lstm hidden states that will allow - # to properly initialize the LSTM at the beginning of each sequence - lstm_states_pi = ( - # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim) - # 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim) - # 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim) - self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), - self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), - ) - lstm_states_vf = ( - # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) - self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), - self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), - ) - lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) - lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) + + # Use the stored recurrent layer type to determine state structure + if self.recurrent_layer_type == "lstm": + # LSTM case: return tuples + lstm_states_pi = ( + # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim) + # 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim) + # 3. 
(n_seq, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_vf = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) + lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) + else: + # RNN case: return single tensors (only hidden states, no cell states) + lstm_states_pi = self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1) + lstm_states_vf = self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1) + lstm_states_pi = self.to_torch(lstm_states_pi).contiguous() + lstm_states_vf = self.to_torch(lstm_states_vf).contiguous() return RecurrentRolloutBufferSamples( # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) @@ -256,6 +282,7 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer): Equivalent to classic advantage when set to 1. :param gamma: Discount factor :param n_envs: Number of parallel environments + :param recurrent_layer_type: Type of recurrent layer ("lstm" or "rnn") """ def __init__( @@ -268,9 +295,11 @@ def __init__( gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, + recurrent_layer_type: str = "lstm", ): self.hidden_state_shape = hidden_state_shape self.seq_start_indices, self.seq_end_indices = None, None + self.recurrent_layer_type = recurrent_layer_type.lower() super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) def reset(self): @@ -284,10 +313,21 @@ def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) - self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) - self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) - self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + # Actor + if self.recurrent_layer_type == "lstm": + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + else: + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi.cpu().numpy()) + self.cell_states_pi[self.pos] = np.zeros_like(self.hidden_states_pi[self.pos]) + + # Critic + if self.recurrent_layer_type == "lstm": + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + else: + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf.cpu().numpy()) + self.cell_states_vf[self.pos] = np.zeros_like(self.hidden_states_vf[self.pos]) super().add(*args, **kwargs) @@ -354,20 +394,28 @@ def _get_samples( n_seq = len(self.seq_start_indices) max_length = self.pad(self.actions[batch_inds]).shape[1] padded_batch_size = n_seq * max_length - # We retrieve the lstm hidden states that will allow - # to properly initialize the LSTM at the beginning of each sequence - lstm_states_pi = ( - # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) - 
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), - self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), - ) - lstm_states_vf = ( - # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) - self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), - self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), - ) - lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) - lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) + + # Use the stored recurrent layer type to determine state structure + if self.recurrent_layer_type == "lstm": + # LSTM case: return tuples + lstm_states_pi = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_vf = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) + lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) + else: + # RNN case: return single tensors (only hidden states, no cell states) + lstm_states_pi = self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1) + lstm_states_vf = self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1) + lstm_states_pi = self.to_torch(lstm_states_pi).contiguous() + lstm_states_vf = self.to_torch(lstm_states_vf).contiguous() observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index ef7b987e..ed71413a 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -52,6 +52,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param recurrent_layer_type: Type of recurrent layer to use (LSTM or vanilla RNN). + By default, LSTM is used. :param n_lstm_layers: Number of LSTM layers. 
:param shared_lstm: Whether the LSTM is shared between the actor and the critic (in that case, only the actor gradient is used) @@ -80,6 +82,7 @@ def __init__( normalize_images: bool = True, optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[dict[str, Any]] = None, + recurrent_layer_type: str = "lstm", lstm_hidden_size: int = 256, n_lstm_layers: int = 1, shared_lstm: bool = False, @@ -107,10 +110,15 @@ def __init__( optimizer_kwargs, ) + self.recurrent_layer_type = recurrent_layer_type.lower() + assert self.recurrent_layer_type in ["lstm", "rnn"], "Invalid recurrent_layer_type" + + rnn_class = nn.LSTM if self.recurrent_layer_type == "lstm" else nn.RNN + self.lstm_kwargs = lstm_kwargs or {} self.shared_lstm = shared_lstm self.enable_critic_lstm = enable_critic_lstm - self.lstm_actor = nn.LSTM( + self.lstm_actor = rnn_class( self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers, @@ -137,7 +145,7 @@ def __init__( # Use a separate LSTM for the critic if self.enable_critic_lstm: - self.lstm_critic = nn.LSTM( + self.lstm_critic = rnn_class( self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers, @@ -162,10 +170,10 @@ def _build_mlp_extractor(self) -> None: @staticmethod def _process_sequence( features: th.Tensor, - lstm_states: tuple[th.Tensor, th.Tensor], + lstm_states: Union[tuple[th.Tensor, th.Tensor], th.Tensor], episode_starts: th.Tensor, - lstm: nn.LSTM, - ) -> tuple[th.Tensor, th.Tensor]: + lstm: Union[nn.LSTM, nn.RNN], + ) -> tuple[th.Tensor, Union[tuple[th.Tensor, th.Tensor], th.Tensor]]: """ Do a forward pass in the LSTM network. @@ -179,7 +187,8 @@ def _process_sequence( # LSTM logic # (sequence length, batch size, features dim) # (batch size = n_envs for data collection or n_seq when doing gradient update) - n_seq = lstm_states[0].shape[1] + is_lstm = isinstance(lstm, nn.LSTM) + n_seq = lstm_states[0].shape[1] if is_lstm else lstm_states.shape[1] # Batch to sequence # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim) # note: max length (max sequence length) is always 1 during data collection @@ -196,14 +205,24 @@ def _process_sequence( lstm_output = [] # Iterate over the sequence for features, episode_start in zip_strict(features_sequence, episode_starts): - hidden, lstm_states = lstm( - features.unsqueeze(dim=0), - ( - # Reset the states at the beginning of a new episode - (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0], - (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1], - ), - ) + if is_lstm: + hidden, lstm_states = lstm( + features.unsqueeze(dim=0), + ( + # Reset the states at the beginning of a new episode + (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0], + (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1], + ), + ) + else: + hidden, lstm_states = lstm( + features.unsqueeze(dim=0), + ( + # Reset the states at the beginning of a new episode + (1.0 - episode_start).view(1, n_seq, 1) + * lstm_states + ), + ) lstm_output += [hidden] # Sequence to batch # (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim) @@ -230,7 +249,7 @@ def forward( # Preprocess the observation if needed features = self.extract_features(obs) if self.share_features_extractor: - pi_features = vf_features = features # alis + pi_features = vf_features = features # alias else: pi_features, vf_features = features # latent_pi, latent_vf = self.mlp_extractor(features) @@ -240,7 +259,10 @@ def forward( elif self.shared_lstm: # Re-use LSTM features but do not 
backpropagate latent_vf = latent_pi.detach() - lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) + if self.recurrent_layer_type == "lstm": + lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) + else: + lstm_states_vf = lstm_states_pi.detach() else: # Critic only has a feedforward network latent_vf = self.critic(vf_features) @@ -259,7 +281,7 @@ def forward( def get_distribution( self, obs: th.Tensor, - lstm_states: tuple[th.Tensor, th.Tensor], + lstm_states: Union[tuple[th.Tensor, th.Tensor], th.Tensor], episode_starts: th.Tensor, ) -> tuple[Distribution, tuple[th.Tensor, ...]]: """ @@ -280,7 +302,7 @@ def get_distribution( def predict_values( self, obs: th.Tensor, - lstm_states: tuple[th.Tensor, th.Tensor], + lstm_states: Union[tuple[th.Tensor, th.Tensor], th.Tensor], episode_starts: th.Tensor, ) -> th.Tensor: """ @@ -347,7 +369,7 @@ def evaluate_actions( def _predict( self, observation: th.Tensor, - lstm_states: tuple[th.Tensor, th.Tensor], + lstm_states: Union[tuple[th.Tensor, th.Tensor], th.Tensor], episode_starts: th.Tensor, deterministic: bool = False, ) -> tuple[th.Tensor, tuple[th.Tensor, ...]]: @@ -392,25 +414,47 @@ def predict( n_envs = observation[next(iter(observation.keys()))].shape[0] else: n_envs = observation.shape[0] - # state : (n_layers, n_envs, dim) + + # state : (n_layers, n_envs, dim) for both LSTM and RNN + # For LSTM: state is a tuple (hidden_state, cell_state) + # For RNN: state is just the hidden_state if state is None: # Initialize hidden states to zeros - state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) - state = (state, state) + if self.recurrent_layer_type == "lstm": + # LSTM needs both hidden and cell states + single_state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) + state = (single_state, single_state) + else: + # RNN only needs hidden state + state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) if episode_start is None: episode_start = np.array([False for _ in range(n_envs)]) with th.no_grad(): # Convert to PyTorch tensors - states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor( - state[1], dtype=th.float32, device=self.device - ) + if self.recurrent_layer_type == "lstm": + # LSTM: convert tuple of states + lstm_states = ( + th.tensor(state[0], dtype=th.float32, device=self.device), + th.tensor(state[1], dtype=th.float32, device=self.device), + ) + else: + # RNN: convert single state tensor + lstm_states = th.tensor(state, dtype=th.float32, device=self.device) + episode_starts = th.tensor(episode_start, dtype=th.float32, device=self.device) - actions, states = self._predict( - observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic + actions, lstm_states = self._predict( + observation, lstm_states=lstm_states, episode_starts=episode_starts, deterministic=deterministic ) - states = (states[0].cpu().numpy(), states[1].cpu().numpy()) + + # Convert states back to numpy + if self.recurrent_layer_type == "lstm": + # LSTM: convert tuple back to numpy + lstm_states = (lstm_states[0].cpu().numpy(), lstm_states[1].cpu().numpy()) + else: + # RNN: convert single tensor back to numpy + lstm_states = lstm_states.cpu().numpy() # Convert to numpy actions = actions.cpu().numpy() @@ -428,7 +472,7 @@ def predict( if not vectorized_env: actions = actions.squeeze(axis=0) - return actions, states + return actions, lstm_states class 
RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): @@ -462,6 +506,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param recurrent_layer_type: Type of recurrent layer to use (LSTM or vanilla RNN). + By default, LSTM is used. :param n_lstm_layers: Number of LSTM layers. :param shared_lstm: Whether the LSTM is shared between the actor and the critic. By default, only the actor has a recurrent network. @@ -489,6 +535,7 @@ def __init__( normalize_images: bool = True, optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[dict[str, Any]] = None, + recurrent_layer_type: str = "lstm", lstm_hidden_size: int = 256, n_lstm_layers: int = 1, shared_lstm: bool = False, @@ -513,6 +560,7 @@ def __init__( normalize_images, optimizer_class, optimizer_kwargs, + recurrent_layer_type, lstm_hidden_size, n_lstm_layers, shared_lstm, @@ -552,6 +600,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy): :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param recurrent_layer_type: Type of recurrent layer to use (LSTM or vanilla RNN). + By default, LSTM is used. :param n_lstm_layers: Number of LSTM layers. :param shared_lstm: Whether the LSTM is shared between the actor and the critic. By default, only the actor has a recurrent network. @@ -579,6 +629,7 @@ def __init__( normalize_images: bool = True, optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[dict[str, Any]] = None, + recurrent_layer_type: str = "lstm", lstm_hidden_size: int = 256, n_lstm_layers: int = 1, shared_lstm: bool = False, @@ -603,6 +654,7 @@ def __init__( normalize_images, optimizer_class, optimizer_kwargs, + recurrent_layer_type, lstm_hidden_size, n_lstm_layers, shared_lstm, diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py index 17b9bfca..4a84dd48 100644 --- a/sb3_contrib/common/recurrent/type_aliases.py +++ b/sb3_contrib/common/recurrent/type_aliases.py @@ -5,8 +5,8 @@ class RNNStates(NamedTuple): - pi: tuple[th.Tensor, ...] - vf: tuple[th.Tensor, ...] + pi: tuple[th.Tensor, ...] | th.Tensor + vf: tuple[th.Tensor, ...] 
| th.Tensor class RecurrentRolloutBufferSamples(NamedTuple): diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index aaf2756a..125575a4 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -159,17 +159,27 @@ def _setup_model(self) -> None: raise ValueError("Policy must subclass RecurrentActorCriticPolicy") single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) - # hidden and cell states for actor and critic - self._last_lstm_states = RNNStates( - ( - th.zeros(single_hidden_state_shape, device=self.device), - th.zeros(single_hidden_state_shape, device=self.device), - ), - ( + + recurrent_layer_type = self.policy_kwargs.get("recurrent_layer_type", "lstm").lower() + + if recurrent_layer_type == "lstm": + # LSTM: (hidden_state, cell_state) tuples + self._last_lstm_states = RNNStates( + ( + th.zeros(single_hidden_state_shape, device=self.device), + th.zeros(single_hidden_state_shape, device=self.device), + ), + ( + th.zeros(single_hidden_state_shape, device=self.device), + th.zeros(single_hidden_state_shape, device=self.device), + ), + ) + else: + # RNN: single hidden state tensors + self._last_lstm_states = RNNStates( th.zeros(single_hidden_state_shape, device=self.device), th.zeros(single_hidden_state_shape, device=self.device), - ), - ) + ) hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) @@ -182,6 +192,7 @@ def _setup_model(self) -> None: gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs, + recurrent_layer_type=recurrent_layer_type, ) # Initialize schedules for policy/value clipping @@ -275,11 +286,17 @@ def collect_rollouts( ): terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): - terminal_lstm_state = ( - lstm_states.vf[0][:, idx : idx + 1, :].contiguous(), - lstm_states.vf[1][:, idx : idx + 1, :].contiguous(), - ) - # terminal_lstm_state = None + # Handle both LSTM and RNN states + if self.policy.recurrent_layer_type == "lstm": + # LSTM case: (hidden, cell) tuple + terminal_lstm_state = ( + lstm_states.vf[0][:, idx : idx + 1, :].contiguous(), + lstm_states.vf[1][:, idx : idx + 1, :].contiguous(), + ) + else: + # RNN case: single tensor + terminal_lstm_state = lstm_states.vf[:, idx : idx + 1, :].contiguous() + episode_starts = th.tensor([False], dtype=th.float32, device=self.device) terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] rewards[idx] += self.gamma * terminal_value diff --git a/tests/test_cnn.py b/tests/test_cnn.py index e505ea59..d857383d 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -166,11 +166,15 @@ def test_feature_extractor_target_net(model_class, share_features_extractor): @pytest.mark.parametrize("model_class", [TRPO, MaskablePPO, RecurrentPPO, QRDQN, TQC]) @pytest.mark.parametrize("normalize_images", [True, False]) -def test_image_like_input(model_class, normalize_images): +@pytest.mark.parametrize("recurrent_layer_type", ["lstm", "rnn"]) +def test_image_like_input(model_class, normalize_images, recurrent_layer_type): """ Check that we can handle image-like input (3D tensor) when normalize_images=False """ + # Skip RNN test for non-recurrent models + if model_class != RecurrentPPO and recurrent_layer_type == "rnn": + pytest.skip("RNN only applicable to RecurrentPPO") # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution # to check 
that the network handle it automatically @@ -199,6 +203,11 @@ def action_mask_fn(env): ), seed=1, ) + + # Add recurrent layer type for RecurrentPPO + if model_class == RecurrentPPO: + kwargs["policy_kwargs"]["recurrent_layer_type"] = recurrent_layer_type + policy = "CnnLstmPolicy" if model_class == RecurrentPPO else "CnnPolicy" if model_class in {TRPO, MaskablePPO, RecurrentPPO}: diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 003dbcbb..320529c5 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -87,6 +87,9 @@ def test_env(): lstm_hidden_size=4, share_features_extractor=False, ), + dict(recurrent_layer_type="rnn"), + dict(recurrent_layer_type="rnn", lstm_hidden_size=8), + dict(recurrent_layer_type="rnn", enable_critic_lstm=False), ], ) def test_cnn(policy_kwargs): @@ -119,6 +122,8 @@ def test_cnn(policy_kwargs): lstm_kwargs=dict(dropout=0.5), n_lstm_layers=2, ), + dict(recurrent_layer_type="rnn"), + dict(recurrent_layer_type="rnn", lstm_hidden_size=8, n_lstm_layers=2), ], ) def test_policy_kwargs(policy_kwargs): @@ -198,6 +203,8 @@ def test_run_sde(): lstm_kwargs=dict(dropout=0.5), n_lstm_layers=2, ), + dict(recurrent_layer_type="rnn"), + dict(recurrent_layer_type="rnn", lstm_hidden_size=8), ], ) def test_dict_obs(policy_kwargs):
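
The snippet below is a minimal sketch, not part of the diff, for exercising the new RNN code path end to end. It assumes the behaviour implemented above: ``recurrent_layer_type`` is read from ``policy_kwargs``, the actor's recurrent layer is built as ``nn.RNN``, and ``predict()`` returns a single hidden-state array instead of a ``(hidden, cell)`` tuple. The environment id, ``n_steps``/``batch_size`` values and hidden size are arbitrary choices for illustration.

.. code-block:: python

  import numpy as np
  import torch.nn as nn

  from sb3_contrib import RecurrentPPO

  # Build a small model that uses the vanilla-RNN path added in this PR
  model = RecurrentPPO(
      "MlpLstmPolicy",
      "CartPole-v1",
      n_steps=32,
      batch_size=32,
      policy_kwargs=dict(recurrent_layer_type="rnn", lstm_hidden_size=16),
      verbose=0,
  )

  # The actor's recurrent layer should now be an nn.RNN rather than an nn.LSTM
  assert isinstance(model.policy.lstm_actor, nn.RNN)

  # With an RNN, predict() returns a single hidden-state array
  # (for LSTMs it returns a (hidden_state, cell_state) tuple)
  obs = model.get_env().reset()
  episode_starts = np.ones((1,), dtype=bool)
  action, state = model.predict(obs, episode_start=episode_starts, deterministic=True)
  assert isinstance(state, np.ndarray)  # (n_lstm_layers, n_envs, lstm_hidden_size)

  # One short training call to go through the recurrent buffer path as well
  model.learn(64)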
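
For reviewers, here is a self-contained illustration, again not part of the diff, of the per-step reset that ``_process_sequence`` applies in the RNN branch: the single hidden-state tensor is multiplied by ``1 - episode_start`` so that sequences starting a new episode begin from a zero state, and unlike the LSTM branch there is no cell state to mask. The tensor sizes are made up.

.. code-block:: python

  import torch as th
  import torch.nn as nn

  n_seq, features_dim, hidden_size = 4, 8, 16
  rnn = nn.RNN(features_dim, hidden_size)

  features = th.randn(1, n_seq, features_dim)      # (max length = 1, n_seq, features_dim)
  rnn_states = th.randn(1, n_seq, hidden_size)     # (n_layers, n_seq, hidden_size)
  episode_start = th.tensor([1.0, 0.0, 0.0, 1.0])  # sequences 0 and 3 start a new episode

  # Zero the hidden state where an episode just started, then step the RNN
  hidden, rnn_states = rnn(features, (1.0 - episode_start).view(1, n_seq, 1) * rnn_states)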