diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 3fbd28d8..47aab3f3 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -3,6 +3,7 @@ from sb3_contrib.ars import ARS from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.ppo_recurrent import RecurrentPPO +from sb3_contrib.ppo_attention import AttentionPPO from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC from sb3_contrib.trpo import TRPO @@ -16,6 +17,7 @@ "ARS", "MaskablePPO", "RecurrentPPO", + "AttentionPPO", "QRDQN", "TQC", "TRPO", diff --git a/sb3_contrib/common/attention/__init__.py b/sb3_contrib/common/attention/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sb3_contrib/common/attention/buffers.py b/sb3_contrib/common/attention/buffers.py new file mode 100644 index 00000000..0eee7fc4 --- /dev/null +++ b/sb3_contrib/common/attention/buffers.py @@ -0,0 +1,390 @@ +from functools import partial +from typing import Callable, Generator, Optional, Tuple, Union + +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.vec_env import VecNormalize + +from sb3_contrib.common.attention.type_aliases import ( + AttentionDictRolloutBufferSamples, + AttentionRolloutBufferSamples, + AttnMemory, +) + + +def pad( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Chunk sequences and pad them to have constant dimensions. + + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device + :param tensor: Tensor of shape (batch_size, *tensor_shape) + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) + :return: (n_seq, max_length, *tensor_shape) + """ + # Create sequences given start and end + seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)] + return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value) + + +def pad_and_flatten( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Pad and flatten the sequences of scalar values, + while keeping the sequence order. + From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) + + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device (cpu, gpu, ...) + :param tensor: Tensor of shape (max_length, n_seq, 1) + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) + :return: (n_seq * max_length,) aka (padded_batch_size,) + """ + return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten() + + +def create_sequencers( + episode_starts: np.ndarray, + env_change: np.ndarray, + device: th.device, +) -> Tuple[np.ndarray, Callable, Callable]: + """ + Create the utility function to chunk data into + sequences and pad them to create fixed size tensors. 
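+
+    For illustration (hypothetical values), a single-env rollout of six transitions
+    with an episode start at index 2 and an env change at index 4 is split into
+    three sequences covering transitions (0, 1), (2, 3) and (4, 5)::
+
+        >>> episode_starts = np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
+        >>> env_change = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 0.0])
+        >>> seq_start_indices, local_pad, _ = create_sequencers(episode_starts, env_change, th.device("cpu"))
+        >>> seq_start_indices
+        array([0, 2, 4])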
+
+    :param episode_starts: Indices where an episode starts
+    :param env_change: Indices where the collected data
+        comes from a different env (when using multiple envs for data collection)
+    :param device: PyTorch device
+    :return: Indices of the transitions that start a sequence,
+        pad and pad_and_flatten utilities tailored for this batch
+        (sequence start and end indices are fixed)
+    """
+    # Create a new sequence if the env changes too
+    seq_start = np.logical_or(episode_starts, env_change).flatten()
+    # First index is always the beginning of a sequence
+    seq_start[0] = True
+    # Retrieve indices of sequence starts
+    seq_start_indices = np.where(seq_start == True)[0]  # noqa: E712
+    # Sequence ends are just before the next sequence starts
+    # Last index is also always the end of a sequence
+    seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])])
+
+    # Create padding method for this minibatch
+    # to avoid repeating arguments (seq_start_indices, seq_end_indices)
+    local_pad = partial(pad, seq_start_indices, seq_end_indices, device)
+    local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device)
+    return seq_start_indices, local_pad, local_pad_and_flatten
+
+
+class AttentionRolloutBuffer(RolloutBuffer):
+    """
+    Rollout buffer that also stores the memory of the attention network (GTrXL).
+
+    :param buffer_size: Max number of elements in the buffer
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param memory_shape: Shape of the buffer that will collect the attention memory
+        (n_steps, num_layers, n_envs, attention_dim)
+    :param device: PyTorch device
+    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
+        Equivalent to classic advantage when set to 1.
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + memory_shape: Tuple[int, int, int, int], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.memory_shape = memory_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) + + def reset(self): + super().reset() + self.memory_pi = np.zeros(self.memory_shape, dtype=np.float32) + # self.cell_states_pi = np.zeros(self.memory_shape, dtype=np.float32) + # self.memory_vf = np.zeros(self.memory_shape, dtype=np.float32) + # self.cell_states_vf = np.zeros(self.memory_shape, dtype=np.float32) + + def add(self, *args, attn_memory: AttnMemory, **kwargs) -> None: + """ + :param attn_memory: Attention memory + """ + self.memory_pi[self.pos] = np.array(attn_memory.pi.cpu().numpy()) + # self.cell_states_pi[self.pos] = np.array(attn_memory.pi[1].cpu().numpy()) + # self.memory_vf[self.pos] = np.array(attn_memory.vf[0].cpu().numpy()) + # self.cell_states_vf[self.pos] = np.array(attn_memory.vf[1].cpu().numpy()) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[AttentionRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + + # Prepare the data + if not self.generator_ready: + # memory_shape = (n_steps, num_layers, n_envs, attention_dim) + # swap first to (n_steps, n_envs, num_layers, attention_dim) + for tensor in ["memory_pi"]:#, "cell_states_pi"]:#, "memory_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + # flatten but keep the sequence order + # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape) + # 2. 
(n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape) + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "memory_pi", + # "cell_states_pi", + # "memory_vf", + # "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # Sampling strategy that allows any mini batch size but requires + # more complexity and use of padding + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env_change: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> AttentionRolloutBufferSamples: + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) + + # Number of sequences + n_seq = len(self.seq_start_indices) + max_length = self.pad(self.actions[batch_inds]).shape[1] + # print('n_seq', n_seq, 'max_length', max_length, 'buffer_size', self.buffer_size, 'batch_inds', len(batch_inds)) + padded_batch_size = n_seq * max_length + # Retrieving attention memory that will allow + # to properly initialize the GTrXL at the beginning of each sequence + # attn_memory_pi = ( + # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim) + # 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim) + # 3. 
(n_seq, n_layers, dim) -> (n_layers, n_seq, dim) + # self.memory_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + # self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + # ) + + # print(self.memory_pi.shape, batch_inds, self.seq_start_indices) + attn_memory_pi = self.to_torch(self.memory_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1)).contiguous() + # attn_memory_vf = ( + # # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + # self.memory_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + # self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + # ) + # attn_memory_pi = (self.to_torch(attn_memory_pi[0]).contiguous(), self.to_torch(attn_memory_pi[1]).contiguous()) + # attn_memory_vf = (self.to_torch(attn_memory_vf[0]).contiguous(), self.to_torch(attn_memory_vf[1]).contiguous()) + # print('padded_batch_size', padded_batch_size, 'obs', self.obs_shape) + # print('attn_memory_pi', attn_memory_pi.size()) + return AttentionRolloutBufferSamples( + # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) + observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)), + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + returns=self.pad_and_flatten(self.returns[batch_inds]), + attn_memory=AttnMemory(attn_memory_pi), + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + ) + + +class AttentionDictRolloutBuffer(DictRolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the AttentionRolloutBuffer to use dictionary observations + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param memory_shape: Shape of the buffer that will collect memory + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + memory_shape: Tuple[int, int, int, int], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.memory_shape = memory_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) + + def reset(self): + super().reset() + self.memory_pi = np.zeros(self.memory_shape, dtype=np.float32) + # self.cell_states_pi = np.zeros(self.memory_shape, dtype=np.float32) + # self.memory_vf = np.zeros(self.memory_shape, dtype=np.float32) + # self.cell_states_vf = np.zeros(self.memory_shape, dtype=np.float32) + + def add(self, *args, attn_memory: AttnMemory, **kwargs) -> None: + """ + :param attn_memory: Attention memory + """ + self.memory_pi[self.pos] = np.array(attn_memory.pi.cpu().numpy()) + # self.cell_states_pi[self.pos] = np.array(attn_memory.pi[1].cpu().numpy()) + # self.memory_vf[self.pos] = np.array(attn_memory.vf[0].cpu().numpy()) + # self.cell_states_vf[self.pos] = np.array(attn_memory.vf[1].cpu().numpy()) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[AttentionDictRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + + # Prepare the data + if not self.generator_ready: + # memory_shape = (n_steps, num_layers, n_envs, attention_dim) + # swap first to (n_steps, n_envs, num_layers, attention_dim) + for tensor in ["memory_pi"]:#, "cell_states_pi"]:#, "memory_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + for tensor in [ + "actions", + "values", + "log_probs", + "advantages", + "returns", + "memory_pi", + # "cell_states_pi", + # "memory_vf", + # "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env_change: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> AttentionDictRolloutBufferSamples: + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) + + n_seq = len(self.seq_start_indices) + max_length = self.pad(self.actions[batch_inds]).shape[1] + padded_batch_size = n_seq * 
max_length + # Retrieving attention memory that will allow + # to properly initialize the GTrXL at the beginning of each sequence + # attn_memory_pi = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + # self.memory_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + # self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + # ) + attn_memory_pi = self.memory_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1) + # attn_memory_vf = ( + # # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + # self.memory_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + # self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + # ) + # attn_memory_pi = (self.to_torch(attn_memory_pi[0]).contiguous(), self.to_torch(attn_memory_pi[1]).contiguous()) + # attn_memory_vf = (self.to_torch(attn_memory_vf[0]).contiguous(), self.to_torch(attn_memory_vf[1]).contiguous()) + + observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} + observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} + + return AttentionDictRolloutBufferSamples( + observations=observations, + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + returns=self.pad_and_flatten(self.returns[batch_inds]), + attn_memory=AttnMemory(attn_memory_pi), + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + ) diff --git a/sb3_contrib/common/attention/policies.py b/sb3_contrib/common/attention/policies.py new file mode 100644 index 00000000..6e219e66 --- /dev/null +++ b/sb3_contrib/common/attention/policies.py @@ -0,0 +1,625 @@ +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common.distributions import Distribution +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, +) +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.utils import zip_strict +from torch import nn + +from sb3_contrib.common.attention.type_aliases import AttnMemory +from sb3_contrib.ppo_attention.architecture import GTrXLNet + + +class AttentionActorCriticPolicy(ActorCriticPolicy): + """ + Attention policy class for actor-critic algorithms (has both policy and value prediction). + To be used with A2C, PPO and the likes. + It assumes that both the actor and the critic GTrXL + have the same model. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_layers: Number of layers (MHA + Position-wise MLP) + :param attention_dim: Dimension of the attention latent space + :param num_heads: Number of heads of the MHA + :param memory_inference: (not used) + :param memory_training: (not used) + :param head_dim: Heads dimension of the MHA + :param position_wise_mlp_dim: Dimension of the Position-wise MLP + :param init_gru_gate_bias: Bias initialization of the GRU gates + :param device: PyTorch device. + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + + n_layers: int = 1, + attention_dim: int = 64, + n_heads: int = 2, + memory_inference: int = 50, + memory_training: int = 50, + head_dim: int = 32, + position_wise_mlp_dim: int = 32, + init_gru_gate_bias: float = 2.0, + ): + self.attention_dim = attention_dim + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + share_features_extractor, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + self.n_layers = n_layers + self.n_heads = n_heads + # Same model for actor and critic + self.model = GTrXLNet( + feature_dim=self.features_dim, + n_layers=n_layers, + attention_dim=attention_dim, + num_heads=n_heads, + memory_inference=memory_inference, + memory_training=memory_training, + head_dim=head_dim, + position_wise_mlp_dim=position_wise_mlp_dim, + init_gru_gate_bias=init_gru_gate_bias, + device=self.device, + ) + # For the predict() method, to 
initialize attention memory + # (n_layers, batch_size, attention_dim) + self.memory_shape = (n_layers, 1, attention_dim) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. + """ + self.mlp_extractor = MlpExtractor( + self.attention_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + @staticmethod + def _process_sequence( + features: th.Tensor, + attn_memory: th.Tensor,#Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + model: GTrXLNet, + ) -> Tuple[th.Tensor, th.Tensor]: + """ + Do a forward pass in the GTrXL network. + + :param features: Input tensor + :param attn_memory: previous attention memory of the GTrXL + :param episode_starts: Indicates when a new episode starts, + in that case, we need to reset the attention. + :param model: GTrXL network object. + :return: GTrXL output and updated GTrXL memory. + """ + # (sequence length, batch size, features dim) + # (batch size = n_envs for data collection or n_seq when doing gradient update) + n_seq = attn_memory.shape[1] + # Batch to sequence + # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim) + # note: max length (max sequence length) is always 1 during data collection + features_sequence = features.reshape((n_seq, -1, model.input_size)).swapaxes(0, 1) + episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1) + + # If we don't have to reset the memory in the middle of a sequence + # we can avoid the for loop, which speeds up things + # if th.all(episode_starts == 0.0): + # attn_output, attn_memory = model(features_sequence, attn_memory) + # # attn_output = th.flatten(attn_output.transpose(0, 1), start_dim=0, end_dim=1) + # return attn_output, attn_memory + + outputs = [] + # Iterate over the sequence + # print('features_sequence', features_sequence.size()) + for features, episode_start in zip_strict(features_sequence, episode_starts): + out, attn_memory = model( + features.unsqueeze(dim=0), + ( + # Reset the memory at the beginning of a new episode + (1.0 - episode_start).view(1, -1, 1) * attn_memory + ), + ) + outputs += [out] + # Sequence to batch + # (sequence length, n_seq, out_dim) -> (batch_size, out_dim) + outputs = th.cat(outputs) + #attn_output = th.flatten(th.cat(attn_output).transpose(0, 1), start_dim=0, end_dim=1) + return outputs, attn_memory + + def forward( + self, + obs: th.Tensor, + attn_memory: AttnMemory, + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, AttnMemory]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation. Observation + :param attn_memory: The last attention memory for the GTrXL model. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the attention memory in that case). 
+ :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + # if self.share_features_extractor: + # pi_features = vf_features = features # alis + # else: + # pi_features, vf_features = features + pi_features = vf_features = features # alis + # latent_pi, latent_vf = self.mlp_extractor(features) + latent_pi, attn_memory_pi = self._process_sequence(pi_features, attn_memory.pi, episode_starts, self.model) + # if self.model_critic is not None: + # latent_vf, attn_memory_vf = self._process_sequence(vf_features, attn_memory.vf, episode_starts, self.model_critic) + # elif self.shared_model: + # # Re-use GTrXL features but do not backpropagate + # latent_vf = latent_pi.detach() + # attn_memory_vf = (attn_memory_pi[0].detach(), attn_memory_pi[1].detach()) + # else: + # # Critic only has a feedforward network + # latent_vf = self.critic(vf_features) + # attn_memory_vf = attn_memory_pi + + #print('out:', latent_pi.size(), 'attn_memory:', attn_memory_pi.size()) + # attn_memory_vf = attn_memory_pi.detach() + latent_vf = latent_pi.detach() + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + # Evaluate the values for the given observations + values = self.value_net(latent_pi) # latent_vf + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob, AttnMemory(attn_memory_pi) + + def get_distribution( + self, + obs: th.Tensor, + attn_memory: th.Tensor,#Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> Tuple[Distribution, Tuple[th.Tensor, ...]]: + """ + Get the current policy distribution given the observations. + + :param obs: Observation. + :param attn_memory: The last attention memory for the GTrXL model. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the attention memory in that case). + :return: the action distribution and new memory. + """ + # Call the method from the parent of the parent class + features = super(ActorCriticPolicy, self).extract_features(obs, self.pi_features_extractor) + latent_pi, attn_memory = self._process_sequence(features, attn_memory, episode_starts, self.model) + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + return self._get_action_dist_from_latent(latent_pi), attn_memory + + def predict_values( + self, + obs: th.Tensor, + attn_memory: th.Tensor,#Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: Observation. + :param attn_memory: The attention memory for the GTrXL model. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the attention memory in that case). + :return: the estimated values. 
+ """ + # Call the method from the parent of the parent class + features = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor) + + # if self.model_critic is not None: + # latent_vf, attn_memory = self._process_sequence(features, attn_memory, episode_starts, self.model_critic) + # elif self.shared_model: + # # Use GTrXL from the actor + # latent_pi, _ = self._process_sequence(features, attn_memory, episode_starts, self.model_actor) + # latent_vf = latent_pi.detach() + # else: + # latent_vf = self.critic(features) + latent_pi, _ = self._process_sequence(features, attn_memory, episode_starts, self.model) + latent_vf = latent_pi.detach() + + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + return self.value_net(latent_vf) + + def evaluate_actions( + self, obs: th.Tensor, actions: th.Tensor, attn_memory: AttnMemory, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: Observation. + :param actions: + :param attn_memory: The last attention memory for the GTrXL model. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the attention memory in that case). + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + #print('OBS', obs.size()) + # Preprocess the observation if needed + features = self.extract_features(obs) + #print('FEATURES', features.size()) + # if self.share_features_extractor: + # pi_features = vf_features = features # alias + # else: + # pi_features, vf_features = features + pi_features = vf_features = features # alias + #attn_memory = th.tensor(attn_memory.pi, dtype=th.float32, device=self.device) + latent_pi, _ = self._process_sequence(pi_features, attn_memory.pi, episode_starts, self.model) + # if self.model_critic is not None: + # latent_vf, _ = self._process_sequence(vf_features, attn_memory.vf, episode_starts, self.model_critic) + # elif self.shared_model: + # latent_vf = latent_pi.detach() + # else: + # latent_vf = self.critic(vf_features) + #latent_vf = latent_pi.detach() + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + #latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_pi) #values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def _predict( + self, + observation: th.Tensor, + attn_memory: th.Tensor, #Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, Tuple[th.Tensor, ...]]: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param attn_memory: The last attention memory for the GTrXL model. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the attention memory in that case). 
+ :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy and memory of the Attention network + """ + distribution, attn_memory = self.get_distribution(observation, attn_memory, episode_starts) + return distribution.get_actions(deterministic=deterministic), attn_memory + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + memory: Optional[np.ndarray] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional attention memory). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param memory: The last attention memory for the GTrXL. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the attention memory in that case). + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next memory + (used in attention policies) + """ + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + observation, vectorized_env = self.obs_to_tensor(observation) + + if isinstance(observation, dict): + n_envs = observation[list(observation.keys())[0]].shape[0] + else: + n_envs = observation.shape[0] + # memory : (n_layers, n_envs, dim) + if memory is None: + # Initialize memory to zeros + memory = th.tensor(np.concatenate([np.zeros(self.memory_shape) for _ in range(n_envs)], axis=1), + dtype=th.float32, device=self.device) + + if episode_start is None: + episode_start = np.array([False for _ in range(n_envs)]) + + with th.no_grad(): + # Convert to PyTorch tensors + episode_starts = th.tensor(episode_start, dtype=th.float32, device=self.device) + actions, memory = self._predict( + observation, attn_memory=memory, episode_starts=episode_starts, deterministic=deterministic + ) + + # Convert to numpy + actions = actions.cpu().numpy() + + if isinstance(self.action_space, spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # Remove batch dimension if needed + if not vectorized_env: + actions = actions.squeeze(axis=0) + + return actions, memory + + +class AttentionActorCriticCnnPolicy(AttentionActorCriticPolicy): + """ + CNN attention policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). 
It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_layers: Number of layers (MHA + Position-wise MLP) + :param attention_dim: Dimension of the attention latent space + :param num_heads: Number of heads of the MHA + :param memory_inference: (not used) + :param memory_training: (not used) + :param head_dim: Heads dimension of the MHA + :param position_wise_mlp_dim: Dimension of the Position-wise MLP + :param init_gru_gate_bias: Bias initialization of the GRU gates + :param device: PyTorch device. + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + + n_layers: int = 1, + attention_dim: int = 64, + n_heads: int = 2, + memory_inference: int = 50, + memory_training: int = 50, + head_dim: int = 32, + position_wise_mlp_dim: int = 32, + init_gru_gate_bias: float = 2.0, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + share_features_extractor, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_layers, + attention_dim, + n_heads, + memory_inference, + memory_training, + head_dim, + position_wise_mlp_dim, + init_gru_gate_bias, + ) + + +class AttentionMultiInputActorCriticPolicy(AttentionActorCriticPolicy): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_layers: Number of layers (MHA + Position-wise MLP) + :param attention_dim: Dimension of the attention latent space + :param num_heads: Number of heads of the MHA + :param memory_inference: (not used) + :param memory_training: (not used) + :param head_dim: Heads dimension of the MHA + :param position_wise_mlp_dim: Dimension of the Position-wise MLP + :param init_gru_gate_bias: Bias initialization of the GRU gates + :param device: PyTorch device. + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + + n_layers: int = 1, + attention_dim: int = 64, + n_heads: int = 2, + memory_inference: int = 50, + memory_training: int = 50, + head_dim: int = 32, + position_wise_mlp_dim: int = 32, + init_gru_gate_bias: float = 2.0, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + share_features_extractor, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_layers, + attention_dim, + n_heads, + memory_inference, + memory_training, + head_dim, + position_wise_mlp_dim, + init_gru_gate_bias, + ) diff --git a/sb3_contrib/common/attention/type_aliases.py b/sb3_contrib/common/attention/type_aliases.py new file mode 100644 index 00000000..0ee276ae --- /dev/null +++ b/sb3_contrib/common/attention/type_aliases.py @@ -0,0 +1,33 @@ +from typing import NamedTuple, Tuple + +import torch as th +from stable_baselines3.common.type_aliases import TensorDict + + 
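+# The GTrXL memory is a single tensor of shape (n_layers, batch_size, attention_dim)
+# holding the output of each attention block for the previous timesteps.
+# Only the policy memory is stored: the actor and the critic share the same GTrXL
+# model in this implementation, so a separate ``vf`` entry is not needed
+# (kept commented out below for reference).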
+class AttnMemory(NamedTuple): + pi: th.Tensor #Tuple[th.Tensor, ...] + # vf: Tuple[th.Tensor, ...] + + +class AttentionRolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + attn_memory: AttnMemory + episode_starts: th.Tensor + mask: th.Tensor + + +class AttentionDictRolloutBufferSamples(NamedTuple): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + attn_memory: AttnMemory + episode_starts: th.Tensor + mask: th.Tensor diff --git a/sb3_contrib/ppo_attention/__init__.py b/sb3_contrib/ppo_attention/__init__.py new file mode 100644 index 00000000..7a06ac37 --- /dev/null +++ b/sb3_contrib/ppo_attention/__init__.py @@ -0,0 +1,4 @@ +from sb3_contrib.ppo_attention.policies import CnnAttnPolicy, MlpAttnPolicy, MultiInputAttnPolicy +from sb3_contrib.ppo_attention.ppo_attention import AttentionPPO + +__all__ = ["CnnAttnPolicy", "MlpAttnPolicy", "MultiInputAttnPolicy", "AttentionPPO"] diff --git a/sb3_contrib/ppo_attention/architecture.py b/sb3_contrib/ppo_attention/architecture.py new file mode 100644 index 00000000..36bd670b --- /dev/null +++ b/sb3_contrib/ppo_attention/architecture.py @@ -0,0 +1,474 @@ +import torch +from torch import nn +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, Callable +from stable_baselines3.common.utils import get_device + +# Code from RLlib: https://github.com/ray-project/ray/blob/master/rllib/models/torch/attention_net.py + +def sequence_mask( + lengths, + maxlen: Optional[int] = None, + dtype=None, + time_major: bool = False, +): + """Offers same behavior as tf.sequence_mask for torch. + Thanks to Dimitris Papatheodorou + (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/ + 39036). + :param lengths: The tensor of individual lengths to mask by. + :param maxlen: The maximum length to use for the time axis. If None, use + the max of `lengths`. + :param dtype: The torch dtype to use for the resulting mask. + :param time_major: Whether to return the mask as [B, T] (False; default) or + as [T, B] (True). + :return: The sequence mask resulting from the given input and parameters. + """ + # If maxlen not given, use the longest lengths in the `lengths` tensor. + if maxlen is None: + maxlen = int(lengths.max()) + + mask = ~( + torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() + > lengths + ) + # Time major transformation. + if not time_major: + mask = mask.t() + + # By default, set the mask to be boolean. + mask.type(dtype or torch.bool) + + return mask + +class SlimFC(nn.Module): + """Simple PyTorch version of `linear` function""" + + def __init__( + self, + in_size: int, + out_size: int, + initializer: Any = None, + activation_fn: Any = None, + use_bias: bool = True, + bias_init: float = 0.0, + ): + """Creates a standard FC layer, similar to torch.nn.Linear + :param in_size: Input size for FC Layer + :param out_size: Output size for FC Layer + :param initializer: Initializer function for FC layer weights + :param activation_fn: Activation function at the end of layer + :param use_bias: Whether to add bias weights or not + :param bias_init: Initalize bias weights to bias_init const + """ + super(SlimFC, self).__init__() + layers = [] + # Actual nn.Linear layer (including correct initialization logic). 
+ linear = nn.Linear(in_size, out_size, bias=use_bias) + if initializer is None: + initializer = nn.init.xavier_uniform_ + initializer(linear.weight) + if use_bias is True: + nn.init.constant_(linear.bias, bias_init) + layers.append(linear) + # Activation function (if any; default=None (linear)). + if activation_fn is not None: + layers.append(activation_fn()) + # Put everything in sequence. + self._model = nn.Sequential(*layers) + + def forward(self, x): + return self._model(x) + + +class GRUGate(nn.Module): + """Implements a gated recurrent unit for use in AttentionNet""" + + def __init__(self, dim: int, init_bias: int = 0.0, **kwargs): + """ + :param input_shape (torch.Tensor): dimension of the input + :param init_bias: Bias added to every input to stabilize training + """ + super().__init__(**kwargs) + # Xavier initialization of torch tensors + self._w_r = nn.Parameter(torch.zeros(dim, dim)) + self._w_z = nn.Parameter(torch.zeros(dim, dim)) + self._w_h = nn.Parameter(torch.zeros(dim, dim)) + nn.init.xavier_uniform_(self._w_r) + nn.init.xavier_uniform_(self._w_z) + nn.init.xavier_uniform_(self._w_h) + self.register_parameter("_w_r", self._w_r) + self.register_parameter("_w_z", self._w_z) + self.register_parameter("_w_h", self._w_h) + + self._u_r = nn.Parameter(torch.zeros(dim, dim)) + self._u_z = nn.Parameter(torch.zeros(dim, dim)) + self._u_h = nn.Parameter(torch.zeros(dim, dim)) + nn.init.xavier_uniform_(self._u_r) + nn.init.xavier_uniform_(self._u_z) + nn.init.xavier_uniform_(self._u_h) + self.register_parameter("_u_r", self._u_r) + self.register_parameter("_u_z", self._u_z) + self.register_parameter("_u_h", self._u_h) + + self._bias_z = nn.Parameter( + torch.zeros( + dim, + ).fill_(init_bias) + ) + self.register_parameter("_bias_z", self._bias_z) + + def forward(self, inputs, **kwargs): + # Pass in internal state first. + h, X = inputs + + r = torch.tensordot(X, self._w_r, dims=1) + torch.tensordot( + h, self._u_r, dims=1 + ) + r = torch.sigmoid(r) + + z = ( + torch.tensordot(X, self._w_z, dims=1) + + torch.tensordot(h, self._u_z, dims=1) + - self._bias_z + ) + z = torch.sigmoid(z) + + h_next = torch.tensordot(X, self._w_h, dims=1) + torch.tensordot( + (h * r), self._u_h, dims=1 + ) + h_next = torch.tanh(h_next) + + return (1 - z) * h + z * h_next + + +class SkipConnection(nn.Module): + """Skip connection layer. + Adds the original input to the output (regular residual layer) OR uses + input as hidden state input to a given fan_in_layer. + """ + + def __init__(self, layer: nn.Module, fan_in_layer: Optional[nn.Module] = None, **kwargs): + """Initializes a SkipConnection nn Module object. + :param layer (nn.Module): Any layer processing inputs. + :param fan_in_layer (Optional[nn.Module]): An optional + layer taking two inputs: The original input and the output + of `layer`. + """ + super().__init__(**kwargs) + self._layer = layer + self._fan_in_layer = fan_in_layer + + def forward(self, inputs, **kwargs): + # del kwargs + outputs = self._layer(inputs, **kwargs) + # Residual case, just add inputs to outputs. + if self._fan_in_layer is None: + outputs = outputs + inputs + # Fan-in e.g. RNN: Call fan-in with `inputs` and `outputs`. + else: + # NOTE: In the GRU case, `inputs` is the state input. + outputs = self._fan_in_layer((inputs, outputs)) + + return outputs + + +class RelativePositionEmbedding(nn.Module): + """Creates a [seq_length x seq_length] matrix for rel. pos encoding. + Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding + matrix. 
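+    The returned embedding has shape ``(seq_length, 1, out_dim)``, with positions
+    ordered from the oldest timestep (``seq_length - 1``) down to the most recent (``0``),
+    using the standard Transformer frequencies ``1 / 10000**(2i / out_dim)``.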
+ :param seq_length: The max. sequence length (time axis). + :param out_dim: The number of nodes to go into the first Tranformer + layer with. + :return: The encoding matrix Phi. + """ + + def __init__(self, out_dim, **kwargs): + super().__init__() + self.out_dim = out_dim + + out_range = torch.arange(0, self.out_dim, 2.0) + inverse_freq = 1 / (10000 ** (out_range / self.out_dim)) + self.register_buffer("inverse_freq", inverse_freq) + + def forward(self, seq_length): + pos_input = torch.arange(seq_length - 1, -1, -1.0, dtype=torch.float).to( + self.inverse_freq.device + ) + sinusoid_input = torch.einsum("i,j->ij", pos_input, self.inverse_freq) + pos_embeddings = torch.cat( + [torch.sin(sinusoid_input), torch.cos(sinusoid_input)], dim=-1 + ) + return pos_embeddings[:, None, :] + + +class RelativeMultiHeadAttention(nn.Module): + """A RelativeMultiHeadAttention layer as described in [3]. + Uses segment level recurrence with state reuse. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_heads: int, + head_dim: int, + input_layernorm: bool = False, + output_activation: Union[str, callable] = None, + **kwargs + ): + """Initializes a RelativeMultiHeadAttention nn.Module object. + :param in_dim (int): + :param out_dim: The output dimension of this module. Also known as + "attention dim". + :param num_heads: The number of attention heads to use. + Denoted `H` in [2]. + :param head_dim: The dimension of a single(!) attention head + Denoted `D` in [2]. + :param input_layernorm: Whether to prepend a LayerNorm before + everything else. Should be True for building a GTrXL. + :param output_activation (Union[str, callable]): Optional activation + function or activation function specifier (str). + Should be "relu" for GTrXL. + :param **kwargs: + """ + super().__init__(**kwargs) + + # No bias or non-linearity. + self._num_heads = num_heads + self._head_dim = head_dim + + # 3=Query, key, and value inputs. + self._qkv_layer = SlimFC( + in_size=in_dim, out_size=3 * num_heads * head_dim, use_bias=False + ) + + self._linear_layer = SlimFC( + in_size=num_heads * head_dim, + out_size=out_dim, + use_bias=False, + activation_fn=output_activation, + ) + + self._uvar = nn.Parameter(torch.zeros(num_heads, head_dim)) + self._vvar = nn.Parameter(torch.zeros(num_heads, head_dim)) + nn.init.xavier_uniform_(self._uvar) + nn.init.xavier_uniform_(self._vvar) + self.register_parameter("_uvar", self._uvar) + self.register_parameter("_vvar", self._vvar) + + self._pos_proj = SlimFC( + in_size=in_dim, out_size=num_heads * head_dim, use_bias=False + ) + self._rel_pos_embedding = RelativePositionEmbedding(out_dim) + + self._input_layernorm = None + if input_layernorm: + self._input_layernorm = torch.nn.LayerNorm(in_dim) + + #print('in_dim', in_dim) + + def forward(self, inputs, memory=None): + T = inputs.shape[1] #list(inputs.size())[1] # length of segment (time) + H = self._num_heads # number of attention heads + d = self._head_dim # attention head dimension + + # Add previous memory chunk (as const, w/o gradient) to input. + # Tau (number of (prev) time slices in each memory chunk). + Tau = memory.shape[1] #list(memory.shape)[1] + inputs = torch.cat((memory.detach(), inputs), dim=1) + + # Apply the Layer-Norm. + if self._input_layernorm is not None: + inputs = self._input_layernorm(inputs) + + qkv = self._qkv_layer(inputs) + + queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1) + # Cut out Tau memory timesteps from query. 
+ #if memory is not None: + queries = queries[:, -T:] + + queries = torch.reshape(queries, [-1, T, H, d]) + keys = torch.reshape(keys, [-1, Tau + T, H, d]) + values = torch.reshape(values, [-1, Tau + T, H, d]) + + R = self._pos_proj(self._rel_pos_embedding(Tau + T)) + R = torch.reshape(R, [Tau + T, H, d]) + + # b=batch + # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space) + # h=head + # d=head-dim (over which we will reduce-sum) + score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys) + pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R) + score = score + self.rel_shift(pos_score) + score = score / d**0.5 + + # causal mask of the same length as the sequence + mask = sequence_mask(torch.arange(Tau + 1, Tau + T + 1), dtype=score.dtype).to(score.device) + mask = mask[None, :, :, None] + + masked_score = score * mask + 1e30 * (mask.float() - 1.0) + wmat = nn.functional.softmax(masked_score, dim=2) + + out = torch.einsum("bijh,bjhd->bihd", wmat, values) + shape = list(out.shape)[:2] + [H * d] + out = torch.reshape(out, shape) + + return self._linear_layer(out) + + @staticmethod + def rel_shift(x): + # Transposed version of the shift approach described in [3]. + # https://github.com/kimiyoung/transformer-xl/blob/ + # 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31 + x_size = list(x.shape) + + x = torch.nn.functional.pad(x, (0, 0, 1, 0, 0, 0, 0, 0)) + x = torch.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]]) + x = x[:, 1:, :, :] + x = torch.reshape(x, x_size) + + return x + + +class GTrXLNet(nn.Module): + """ + Constructs an Attention that receives the output from a previous features extractor or directly the observations (if no features extractor is applied) as an input and outputs a latent representation for the policy and a value network. + :param feature_dim: Dimension of the feature vector (can be the output of a CNN) + :param n_layers: Number of layers (MHA + Position-wise MLP) + :param attention_dim: Dimension of the attention latent space + :param num_heads: Number of heads of the MHA + :param memory_inference: (not used) + :param memory_training: (not used) + :param head_dim: Heads dimension of the MHA + :param position_wise_mlp_dim: Dimension of the Position-wise MLP + :param init_gru_gate_bias: Bias initialization of the GRU gates + :param device: PyTorch device. + """ + def __init__( + self, + feature_dim: int, + n_layers: int = 1, + attention_dim: int = 64, + num_heads: int = 2, + memory_inference: int = 50, + memory_training: int = 50, + head_dim: int = 32, + position_wise_mlp_dim: int = 32, + init_gru_gate_bias: float = 2.0, + device: Union[torch.device, str] = "auto", + ) -> None: + super().__init__() + + device = get_device(device) + self.input_size = feature_dim + self.n_layers = n_layers + self.attention_dim = attention_dim + self.num_heads = num_heads + self.memory_inference = memory_inference + self.memory_training = memory_training + self.head_dim = head_dim + + self.linear_layer = SlimFC(in_size=feature_dim, out_size=self.attention_dim) + self.layers = [self.linear_layer] + + attention_layers = [] + for i in range(self.n_layers): + # RelativeMultiHeadAttention part. + MHA_layer = SkipConnection( + RelativeMultiHeadAttention( + in_dim=self.attention_dim, + out_dim=self.attention_dim, + num_heads=num_heads, + head_dim=head_dim, + input_layernorm=True, + output_activation=nn.ReLU, + ), + fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias), + ) + + # Position-wise MultiLayerPerceptron part. 
+ list_e_layer = [torch.nn.LayerNorm(self.attention_dim), + SlimFC( + in_size=self.attention_dim, + out_size=position_wise_mlp_dim, + use_bias=False, + activation_fn=nn.ReLU, + ), + SlimFC( + in_size=position_wise_mlp_dim, + out_size=self.attention_dim, + use_bias=False, + activation_fn=nn.ReLU, + )] + E_layer = SkipConnection( + nn.Sequential(*list_e_layer), + fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias), + ) + + # Build a list of all attanlayers in order. + attention_layers.extend([MHA_layer, E_layer]) + + # Create a Sequential such that all parameters inside the attention + # layers are automatically registered with this top-level model. + self.attention_layers = nn.Sequential(*attention_layers).to(device) + self.layers.extend(attention_layers) + + # Final layers if num_outputs not None. + self.logits = None + self.values_out = None + # Last value output. + self._value_out = None + # Postprocess GTrXL output with another hidden layer. + # if self.num_outputs is not None: + # self.logits = SlimFC( + # in_size=self.attention_dim, + # out_size=self.num_outputs, + # activation_fn=nn.ReLU, + # ) + + # # Value function used by all RLlib Torch RL implementations. + # self.values_out = SlimFC( + # in_size=self.attention_dim, out_size=1, activation_fn=None + # ) + # else: + # self.num_outputs = self.attention_dim + self.num_outputs = self.attention_dim + + + def forward(self, features: torch.Tensor, memory: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + memory_outs = [] + # print('features:', features.size(), ' | memory:', memory.size()) + for i in range(len(self.layers)): + # MHA layers which need memory passed in. + if i % 2 == 1: + features = self.layers[i](features, memory=memory[i//2].unsqueeze(0)) + # Either self.linear_layer (initial obs -> attn. dim layer) or + # MultiLayerPerceptrons. The output of these layers is always the + # memory for the next forward pass. + else: + features = self.layers[i](features) + memory_outs.append(features) + + # Discard last output (not needed as a memory since it's the last + # layer). 
+        memory_outs = memory_outs[:-1]
+
+        if self.logits is not None:
+            out = self.logits(features)
+            self._value_out = self.values_out(features)
+            out_dim = self.num_outputs
+        else:
+            out = features
+            out_dim = self.attention_dim
+
+        return torch.reshape(out, [-1, out_dim]), torch.concat([
+            torch.reshape(m, [1, -1, self.attention_dim]) for m in memory_outs
+        ], dim=0)
\ No newline at end of file
diff --git a/sb3_contrib/ppo_attention/policies.py b/sb3_contrib/ppo_attention/policies.py
new file mode 100644
index 00000000..31929bc1
--- /dev/null
+++ b/sb3_contrib/ppo_attention/policies.py
@@ -0,0 +1,9 @@
+from sb3_contrib.common.attention.policies import (
+    AttentionActorCriticCnnPolicy,
+    AttentionActorCriticPolicy,
+    AttentionMultiInputActorCriticPolicy,
+)
+
+MlpAttnPolicy = AttentionActorCriticPolicy
+CnnAttnPolicy = AttentionActorCriticCnnPolicy
+MultiInputAttnPolicy = AttentionMultiInputActorCriticPolicy
diff --git a/sb3_contrib/ppo_attention/ppo_attention.py b/sb3_contrib/ppo_attention/ppo_attention.py
new file mode 100644
index 00000000..a5acfc85
--- /dev/null
+++ b/sb3_contrib/ppo_attention/ppo_attention.py
@@ -0,0 +1,509 @@
+import sys
+import time
+from copy import deepcopy
+from typing import Any, Dict, Optional, Type, TypeVar, Union
+
+import numpy as np
+import torch as th
+from gym import spaces
+from stable_baselines3.common.buffers import RolloutBuffer
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
+from stable_baselines3.common.vec_env import VecEnv
+
+from sb3_contrib.common.attention.buffers import AttentionDictRolloutBuffer, AttentionRolloutBuffer
+from sb3_contrib.common.attention.policies import AttentionActorCriticPolicy
+from sb3_contrib.common.attention.type_aliases import AttnMemory
+from sb3_contrib.ppo_attention.policies import CnnAttnPolicy, MlpAttnPolicy, MultiInputAttnPolicy
+
+SelfAttentionPPO = TypeVar("SelfAttentionPPO", bound="AttentionPPO")
+
+
+class AttentionPPO(OnPolicyAlgorithm):
+    """
+    Proximal Policy Optimization algorithm (PPO) (clip version)
+    with support for attention policies (GTrXL).
+
+    Based on the original Stable Baselines 3 implementation.
+
+    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
+
+    :param policy: The policy model to use (MlpAttnPolicy, CnnAttnPolicy, ...)
+    :param env: The environment to learn from (if registered in Gym, can be str)
+    :param learning_rate: The learning rate, it can be a function
+        of the current progress remaining (from 1 to 0)
+    :param n_steps: The number of steps to run for each environment per update
+        (i.e. batch size is n_steps * n_envs where n_envs is the number of environment copies running in parallel)
+    :param batch_size: Minibatch size
+    :param n_epochs: Number of epochs when optimizing the surrogate loss
+    :param gamma: Discount factor
+    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
+    :param clip_range: Clipping parameter, it can be a function of the current progress
+        remaining (from 1 to 0).
+    :param clip_range_vf: Clipping parameter for the value function,
+        it can be a function of the current progress remaining (from 1 to 0).
+        This is a parameter specific to the OpenAI implementation. If None is passed (default),
+        no clipping will be done on the value function.
+        IMPORTANT: this clipping depends on the reward scaling.
+    :param normalize_advantage: Whether to normalize the advantage or not
+    :param ent_coef: Entropy coefficient for the loss calculation
+    :param vf_coef: Value function coefficient for the loss calculation
+    :param max_grad_norm: The maximum value for the gradient clipping
+    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
+        instead of action noise exploration (default: False)
+    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
+        Default: -1 (only sample at the beginning of the rollout)
+    :param target_kl: Limit the KL divergence between updates,
+        because the clipping is not enough to prevent large updates
+        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
+        By default, there is no limit on the kl div.
+    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
+        the reported success rate, mean episode length, and mean reward over
+    :param tensorboard_log: the log location for tensorboard (if None, no logging)
+    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+    :param seed: Seed for the pseudo random generators
+    :param device: Device (cpu, cuda, ...) on which the code should be run.
+        Setting it to auto, the code will be run on the GPU if possible.
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpAttnPolicy": MlpAttnPolicy, + "CnnAttnPolicy": CnnAttnPolicy, + "MultiInputAttnPolicy": MultiInputAttnPolicy, + } + + def __init__( + self, + policy: Union[str, Type[AttentionActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 128, + batch_size: Optional[int] = 128, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, + target_kl: Optional[float] = None, + stats_window_size: int = 100, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + super().__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + stats_window_size=stats_window_size, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + seed=seed, + device=device, + _init_setup_model=False, + supported_action_spaces=None, + ) + """ + supported_action_spaces=( + spaces.Box, + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + """ + + # Sanity check, otherwise it will lead to noisy gradient and NaN + # because of the advantage normalization + if normalize_advantage: + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" + + if self.env is not None: + # Check that `n_steps * n_envs > 1` to avoid NaN + # when doing advantage normalization + buffer_size = self.env.num_envs * self.n_steps + assert buffer_size > 1 or ( + not normalize_advantage + ), f"`n_steps * n_envs` must be greater than 1. 
Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + + self.batch_size = batch_size + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage + self.target_kl = target_kl + self._last_attn_memory = None + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + buffer_cls = AttentionDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else AttentionRolloutBuffer + + self.policy = self.policy_class( + self.observation_space, + self.action_space, + self.lr_schedule, + use_sde=self.use_sde, + **self.policy_kwargs, # pytype:disable=not-instantiable + ) + # We assume the same architecture for the actor and the critic + self.policy = self.policy.to(self.device) + + if not isinstance(self.policy, AttentionActorCriticPolicy): + raise ValueError("Policy must subclass AttentionActorCriticPolicy") + + single_memory_shape = (self.policy.n_layers, self.n_envs, self.policy.attention_dim) + # attention memory for actor + self._last_attn_memory = AttnMemory( + th.zeros(single_memory_shape, device=self.device), + # ( + # th.zeros(single_memory_shape, device=self.device), + # th.zeros(single_memory_shape, device=self.device), + # ), + # ( + # th.zeros(single_memory_shape, device=self.device), + # th.zeros(single_memory_shape, device=self.device), + # ), + ) + + # print('n_steps', self.n_steps) + + hidden_memory_buffer_shape = (self.n_steps, self.policy.n_layers, self.n_envs, self.policy.attention_dim) + + self.rollout_buffer = buffer_cls( + self.n_steps, + self.observation_space, + self.action_space, + hidden_memory_buffer_shape, + self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + ) + + # Initialize schedules for policy/value clipping + self.clip_range = get_schedule_fn(self.clip_range) + if self.clip_range_vf is not None: + if isinstance(self.clip_range_vf, (float, int)): + assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping" + + self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. 
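+
+        In addition to the usual transition data, the attention memory used by
+        the policy at each step is stored in the rollout buffer, so that
+        ``train()`` can later replay the collected sequences with the
+        corresponding memory.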
+ """ + assert isinstance( + rollout_buffer, (AttentionRolloutBuffer, AttentionDictRolloutBuffer) + ), f"{rollout_buffer} doesn't support attention policy" + + assert self._last_obs is not None, "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + + n_steps = 0 + rollout_buffer.reset() + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise(env.num_envs) + + callback.on_rollout_start() + + attn_memory = deepcopy(self._last_attn_memory) + + while n_steps < n_rollout_steps: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.policy.reset_noise(env.num_envs) + + with th.no_grad(): + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) + episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device) + actions, values, log_probs, attn_memory = self.policy.forward(obs_tensor, attn_memory, episode_starts) + + actions = actions.cpu().numpy() + + # Rescale and perform action + clipped_actions = actions + # Clip the actions to avoid out of bound error + if isinstance(self.action_space, spaces.Box): + clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + + new_obs, rewards, dones, infos = env.step(clipped_actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if callback.on_step() is False: + return False + + self._update_info_buffer(infos) + n_steps += 1 + + if isinstance(self.action_space, spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) + + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + for idx, done_ in enumerate(dones): + if ( + done_ + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + with th.no_grad(): + terminal_attn_memory = attn_memory.pi + # terminal_attn_memory = None + episode_starts = th.tensor([False], dtype=th.float32, device=self.device) + terminal_value = self.policy.predict_values(terminal_obs, terminal_attn_memory, episode_starts)[0] + rewards[idx] += self.gamma * terminal_value + + rollout_buffer.add( + self._last_obs, + actions, + rewards, + self._last_episode_starts, + values, + log_probs, + attn_memory=self._last_attn_memory, + ) + + self._last_obs = new_obs + self._last_episode_starts = dones + self._last_attn_memory = attn_memory + + with th.no_grad(): + # Compute value for the last timestep + episode_starts = th.tensor(dones, dtype=th.float32, device=self.device) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), attn_memory.pi, episode_starts) + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) + + callback.on_rollout_end() + + return True + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. 
+ """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + + continue_training = True + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + # Convert mask from float to bool + mask = rollout_data.mask > 1e-8 + + # Re-sample the noise matrix because the log_std has changed + if self.use_sde: + self.policy.reset_noise(self.batch_size) + + values, log_prob, entropy = self.policy.evaluate_actions( + rollout_data.observations, + actions, + rollout_data.attn_memory, + rollout_data.episode_starts, + ) + + values = values.flatten() + # Normalize advantage + advantages = rollout_data.advantages + if self.normalize_advantage: + advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) + + # clipped surrogate loss + policy_loss_1 = advantages * ratio + policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask]) + + # Logging + pg_losses.append(policy_loss.item()) + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item() + clip_fractions.append(clip_fraction) + + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the different between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + # Mask padded sequences + value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask]) + + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -th.mean(-log_prob[mask]) + else: + entropy_loss = -th.mean(entropy[mask]) + + entropy_losses.append(entropy_loss.item()) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: 
{approx_kl_div:.2f}") + break + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + if not continue_training: + break + + self._n_updates += self.n_epochs + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/loss", loss.item()) + self.logger.record("train/explained_variance", explained_var) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) + + def learn( + self: SelfAttentionPPO, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + tb_log_name: str = "AttentionPPO", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfAttentionPPO: + iteration = 0 + + total_timesteps, callback = self._setup_learn( + total_timesteps, + callback, + reset_num_timesteps, + tb_log_name, + progress_bar, + ) + + callback.on_training_start(locals(), globals()) + + while self.num_timesteps < total_timesteps: + continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) + + if continue_training is False: + break + + iteration += 1 + self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + + # Display training infos + if log_interval is not None and iteration % log_interval == 0: + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) + self.logger.record("time/iterations", iteration, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(step=self.num_timesteps) + + self.train() + + callback.on_training_end() + + return self
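
Usage sketch (not part of the patch): a minimal training example for the new algorithm, mirroring how RecurrentPPO is used; the environment id and hyperparameter values below are illustrative placeholders only.

    from sb3_contrib import AttentionPPO

    # "CartPole-v1" and the hyperparameters are placeholder choices.
    model = AttentionPPO("MlpAttnPolicy", "CartPole-v1", n_steps=128, n_epochs=10, verbose=1)
    model.learn(total_timesteps=10_000)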