diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py
index d7f3de72..f62d4775 100644
--- a/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/algorithms/ppo.py
@@ -141,7 +141,7 @@ def act(self, obs, critic_obs):
         self.transition.values = self.policy.evaluate(critic_obs).detach()
         self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach()
         self.transition.action_mean = self.policy.action_mean.detach()
-        self.transition.action_sigma = self.policy.action_std.detach()
+        self.transition.actions_distribution = self.policy.actions_distribution.detach()
         # need to record obs and critic_obs before env.step()
         self.transition.observations = obs
         self.transition.privileged_observations = critic_obs
@@ -214,12 +214,11 @@ def update(self):  # noqa: C901
             returns_batch,
             old_actions_log_prob_batch,
             old_mu_batch,
-            old_sigma_batch,
+            old_actions_distributions_parameters,
             hid_states_batch,
             masks_batch,
             rnd_state_batch,
         ) in generator:
-
             # number of augmentations per sample
             # we start with 1 and increase it if we use symmetry augmentation
             num_aug = 1
@@ -262,21 +261,16 @@
             # -- entropy
             # we only keep the entropy of the first augmentation (the original one)
             mu_batch = self.policy.action_mean[:original_batch_size]
-            sigma_batch = self.policy.action_std[:original_batch_size]
+            actions_distributions_batch = self.policy.actions_distribution[:original_batch_size]
             entropy_batch = self.policy.entropy[:original_batch_size]
 
             # KL
             if self.desired_kl is not None and self.schedule == "adaptive":
+                current_dist = self.policy.build_distribution(actions_distributions_batch)
+                old_dist = self.policy.build_distribution(old_actions_distributions_parameters)
                 with torch.inference_mode():
-                    kl = torch.sum(
-                        torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
-                        + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
-                        / (2.0 * torch.square(sigma_batch))
-                        - 0.5,
-                        axis=-1,
-                    )
+                    kl = torch.distributions.kl.kl_divergence(old_dist, current_dist).sum(dim=-1)
                     kl_mean = torch.mean(kl)
-
                     # Reduce the KL divergence across all GPUs
                     if self.is_multi_gpu:
                         torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
@@ -304,6 +298,7 @@
 
             # Surrogate loss
             ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
+
             surrogate = -torch.squeeze(advantages_batch) * ratio
             surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
                 ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py
index 0a96bd93..5b360dba 100644
--- a/rsl_rl/modules/__init__.py
+++ b/rsl_rl/modules/__init__.py
@@ -6,6 +6,7 @@
 """Definitions for neural-network components for RL-agents."""
 
 from .actor_critic import ActorCritic
+from .actor_critic_beta import ActorCriticBeta
 from .actor_critic_recurrent import ActorCriticRecurrent
 from .normalizer import EmpiricalNormalization
 from .rnd import RandomNetworkDistillation
@@ -14,6 +15,7 @@
 
 __all__ = [
     "ActorCritic",
+    "ActorCriticBeta",
     "ActorCriticRecurrent",
     "EmpiricalNormalization",
     "RandomNetworkDistillation",
diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py
index 76802484..ee985aca 100644
--- a/rsl_rl/modules/actor_critic.py
+++ b/rsl_rl/modules/actor_critic.py
@@ -25,6 +25,8 @@ def __init__(
         activation="elu",
         init_noise_std=1.0,
         noise_std_type: str = "scalar",
+        clip_actions: bool = False,
+        clip_actions_range: tuple = (-1.0, 1.0),
         **kwargs,
     ):
         if kwargs:
@@ -49,6 +51,11 @@
                 actor_layers.append(activation)
         self.actor = nn.Sequential(*actor_layers)
 
+        self.clip_actions = clip_actions
+        self.clip_actions_range = clip_actions_range
+        if self.clip_actions:
+            self.clipping_layer = nn.Tanh()
+
         # Value function
         critic_layers = []
         critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
@@ -94,19 +101,34 @@ def forward(self):
 
     @property
     def action_mean(self):
-        return self.distribution.mean
-
+        mode = self.distribution.mean
+        if self.clip_actions:
+            mode = ((mode + 1) / 2.0) * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
+        return mode
+
     @property
     def action_std(self):
         return self.distribution.stddev
+
+    @property
+    def actions_distribution(self) -> torch.Tensor:
+        # Mean and std stacked on an extra trailing dimension
+        return torch.stack([self.distribution.mean, self.distribution.stddev], dim=-1)
 
     @property
     def entropy(self):
        return self.distribution.entropy().sum(dim=-1)
 
+    def build_distribution(self, parameters):
+        # build the distribution
+        return Normal(parameters[..., 0], parameters[..., 1])
+
     def update_distribution(self, observations):
         # compute mean
         mean = self.actor(observations)
+        if self.clip_actions:
+            mean = self.clipping_layer(mean)
+
         # compute standard deviation
         if self.noise_std_type == "scalar":
             std = self.std.expand_as(mean)
@@ -119,14 +141,30 @@ def update_distribution(self, observations):
 
     def act(self, observations, **kwargs):
         self.update_distribution(observations)
-        return self.distribution.sample()
+        act = self.distribution.sample()
+        if self.clip_actions:
+            # Apply tanh to clip the actions to [-1, 1]
+            act = self.clipping_layer(act)
+            # Rescale the actions to the desired range
+            act = ((act + 1) / 2.0) * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
+        return act
 
     def get_actions_log_prob(self, actions):
-        return self.distribution.log_prob(actions).sum(dim=-1)
+        # Scale the actions back to [-1, 1] before computing the log probability.
+        if self.clip_actions:
+            # The unscaled actions still have the tanh applied to them.
+            unscaled_actions = (actions - self.clip_actions_range[0]) / (self.clip_actions_range[1] - self.clip_actions_range[0]) * 2.0 - 1.0
+            # Revert the tanh to recover the pre-squash actions, using a numerically stable atanh (`inverse_tanh`).
+            gaussian_actions = self.inverse_tanh(unscaled_actions)
+            return (self.distribution.log_prob(gaussian_actions) - torch.log(1 - unscaled_actions * unscaled_actions + 1e-6)).sum(dim=-1)
+        else:
+            return self.distribution.log_prob(actions).sum(dim=-1)
 
     def act_inference(self, observations):
-        actions_mean = self.actor(observations)
-        return actions_mean
+        mode = self.actor(observations)
+        if self.clip_actions:
+            mode = ((self.clipping_layer(mode) + 1) / 2.0) * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
+        return mode
 
     def evaluate(self, critic_observations, **kwargs):
         value = self.critic(critic_observations)
@@ -147,3 +185,12 @@ def load_state_dict(self, state_dict, strict=True):
 
         super().load_state_dict(state_dict, strict=strict)
         return True
+
+    @staticmethod
+    def atanh(x):
+        return 0.5 * (x.log1p() - (-x).log1p())
+
+    @staticmethod
+    def inverse_tanh(y):
+        eps = torch.finfo(y.dtype).eps
+        return ActorCritic.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
\ No newline at end of file
diff --git a/rsl_rl/modules/actor_critic_beta.py b/rsl_rl/modules/actor_critic_beta.py
new file mode 100644
index 00000000..e72aaa23
--- /dev/null
+++ b/rsl_rl/modules/actor_critic_beta.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from torch.distributions import Beta
+
+from rsl_rl.utils import resolve_nn_activation
+
+
+class ActorCriticBeta(nn.Module):
+    is_recurrent = False
+
+    def __init__(
+        self,
+        num_actor_obs,
+        num_critic_obs,
+        num_actions,
+        actor_hidden_dims=[256, 256, 256],
+        critic_hidden_dims=[256, 256, 256],
+        activation="elu",
+        init_noise_std=1.0,
+        noise_std_type: str = "scalar",
+        clip_actions: bool = True,
+        clip_actions_range: tuple = (-1.0, 1.0),
+        **kwargs,
+    ):
+        if kwargs:
+            print(
+                "ActorCriticBeta.__init__ got unexpected arguments, which will be ignored: "
+                + str([key for key in kwargs.keys()])
+            )
+        super().__init__()
+        activation = resolve_nn_activation(activation)
+
+        mlp_input_dim_a = num_actor_obs
+        mlp_input_dim_c = num_critic_obs
+        # Policy
+        actor_layers = []
+        actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
+        actor_layers.append(activation)
+        for layer_index in range(len(actor_hidden_dims)):
+            if layer_index == len(actor_hidden_dims) - 1:
+                self.alpha = nn.Linear(actor_hidden_dims[layer_index], num_actions)
+                self.beta = nn.Linear(actor_hidden_dims[layer_index], num_actions)
+                self.alpha_activation = nn.Softplus()
+                self.beta_activation = nn.Softplus()
+            else:
+                actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1]))
+                actor_layers.append(activation)
+        self.actor = nn.Sequential(*actor_layers)
+
+        self.clip_actions_range = clip_actions_range
+
+        # Value function
+        critic_layers = []
+        critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
+        critic_layers.append(activation)
+        for layer_index in range(len(critic_hidden_dims)):
+            if layer_index == len(critic_hidden_dims) - 1:
+                critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1))
+            else:
+                critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1]))
+                critic_layers.append(activation)
+        self.critic = nn.Sequential(*critic_layers)
+
+        print(f"Actor MLP: {self.actor}")
+        print(f"Critic MLP: {self.critic}")
+
+        # Action distribution (populated in update_distribution)
+        self.distribution = None
+        self.a = None
+        self.b = None
+        # disable args validation for speedup
+        Beta.set_default_validate_args(False)
+
+    @staticmethod
+    # not used at the moment
+    def init_weights(sequential, scales):
+        [
+            torch.nn.init.orthogonal_(module.weight, gain=scales[idx])
+            for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))
+        ]
+
+    def reset(self, dones=None):
+        pass
+
+    def forward(self):
+        raise NotImplementedError
+
+    @property
+    def action_mean(self):
+        mean = self.a / (self.a + self.b)
+        mean_rescaled = mean * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
+        return mean_rescaled
+
+    @property
+    def action_std(self):
+        return torch.sqrt(self.a * self.b / ((self.a + self.b + 1) * (self.a + self.b) ** 2))
+
+    @property
+    def actions_distribution(self):
+        # Alpha and beta stacked on an extra trailing dimension
+        return torch.stack([self.a, self.b], dim=-1)
+
+    @property
+    def entropy(self):
+        return self.distribution.entropy().sum(dim=-1)
+
+    def build_distribution(self, parameters):
+        # build the distribution
+        return Beta(parameters[..., 0], parameters[..., 1])
+
+    def update_distribution(self, observations):
+        # compute alpha and beta (shifted above 1 so the distribution stays unimodal)
+        latent = self.actor(observations)
+        self.a = self.alpha_activation(self.alpha(latent)) + 1.0
+        self.b = self.beta_activation(self.beta(latent)) + 1.0
+
+        # create distribution
+        self.distribution = Beta(self.a, self.b)
+
+    def act(self, observations, **kwargs):
+        self.update_distribution(observations)
+        act = self.distribution.sample()
+        act_rescaled = act * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
+        return act_rescaled
+
+    def get_actions_log_prob(self, actions):
+        # Unscale the actions to [0, 1] before computing the log probability.
+        unscaled_actions = (actions - self.clip_actions_range[0]) / (self.clip_actions_range[1] - self.clip_actions_range[0])
+        # For numerical stability, clip the actions to [1e-5, 1 - 1e-5].
+        unscaled_actions = torch.clamp(unscaled_actions, 1e-5, 1 - 1e-5)
+        return self.distribution.log_prob(unscaled_actions).sum(dim=-1)
+
+    def act_inference(self, observations):
+        latent = self.actor(observations)
+        self.a = self.alpha_activation(self.alpha(latent)) + 1.0
+        self.b = self.beta_activation(self.beta(latent)) + 1.0
+        mean = self.a / (self.a + self.b)
+        mean_rescaled = mean * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
+        return mean_rescaled
+
+    def evaluate(self, critic_observations, **kwargs):
+        value = self.critic(critic_observations)
+        return value
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Load the parameters of the actor-critic model.
+
+        Args:
+            state_dict (dict): State dictionary of the model.
+            strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
+                module's state_dict() function.
+
+        Returns:
+            bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
+                `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
+ """ + + super().load_state_dict(state_dict, strict=strict) + return True diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 6ca26fa3..5f8dea7f 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -16,6 +16,7 @@ from rsl_rl.env import VecEnv from rsl_rl.modules import ( ActorCritic, + ActorCriticBeta, ActorCriticRecurrent, EmpiricalNormalization, StudentTeacher, @@ -69,7 +70,7 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev # evaluate the policy class policy_class = eval(self.policy_cfg.pop("class_name")) - policy: ActorCritic | ActorCriticRecurrent | StudentTeacher | StudentTeacherRecurrent = policy_class( + policy: ActorCritic | ActorCriticBeta | ActorCriticRecurrent | StudentTeacher | StudentTeacherRecurrent = policy_class( num_obs, num_privileged_obs, self.env.num_actions, **self.policy_cfg ).to(self.device) diff --git a/rsl_rl/storage/rollout_storage.py b/rsl_rl/storage/rollout_storage.py index 42b8c9fd..b328c13d 100644 --- a/rsl_rl/storage/rollout_storage.py +++ b/rsl_rl/storage/rollout_storage.py @@ -8,7 +8,7 @@ import torch from rsl_rl.utils import split_and_pad_trajectories - +import copy class RolloutStorage: class Transition: @@ -22,7 +22,7 @@ def __init__(self): self.values = None self.actions_log_prob = None self.action_mean = None - self.action_sigma = None + self.actions_distribution = None self.hidden_states = None self.rnd_state = None @@ -39,6 +39,7 @@ def __init__( actions_shape, rnd_state_shape=None, device="cpu", + dist_size=2, ): # store inputs self.training_type = training_type @@ -71,7 +72,7 @@ def __init__( self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) - self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) + self.distributions_parameters = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, dist_size, device=self.device) self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) @@ -108,7 +109,7 @@ def add_transitions(self, transition: Transition): self.values[self.step].copy_(transition.values) self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1)) self.mu[self.step].copy_(transition.action_mean) - self.sigma[self.step].copy_(transition.action_sigma) + self.distributions_parameters[self.step].copy_(transition.actions_distribution) # For RND if self.rnd_state_shape is not None: @@ -203,7 +204,7 @@ def mini_batch_generator(self, num_mini_batches, num_epochs=8): old_actions_log_prob = self.actions_log_prob.flatten(0, 1) advantages = self.advantages.flatten(0, 1) old_mu = self.mu.flatten(0, 1) - old_sigma = self.sigma.flatten(0, 1) + old_distributions_parameters = self.distributions_parameters.flatten(0, 1) # For RND if self.rnd_state_shape is not None: @@ -228,7 +229,7 @@ def mini_batch_generator(self, num_mini_batches, num_epochs=8): old_actions_log_prob_batch = old_actions_log_prob[batch_idx] advantages_batch = advantages[batch_idx] old_mu_batch = old_mu[batch_idx] - old_sigma_batch = old_sigma[batch_idx] + old_distributions_parameters_batch = old_distributions_parameters[batch_idx] # -- For RND if self.rnd_state_shape is not None: @@ -237,7 
@@ -237,7 +238,7 @@ def mini_batch_generator(self, num_mini_batches, num_epochs=8):
                     rnd_state_batch = None
 
                 # yield the mini-batch
-                yield obs_batch, privileged_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
+                yield obs_batch, privileged_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_distributions_parameters_batch, (
                     None,
                     None,
                 ), None, rnd_state_batch
@@ -282,7 +283,7 @@ def recurrent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
 
                 actions_batch = self.actions[:, start:stop]
                 old_mu_batch = self.mu[:, start:stop]
-                old_sigma_batch = self.sigma[:, start:stop]
+                old_distributions_parameters_batch = self.distributions_parameters[:, start:stop]
                 returns_batch = self.returns[:, start:stop]
                 advantages_batch = self.advantages[:, start:stop]
                 values_batch = self.values[:, start:stop]
@@ -308,7 +309,7 @@ def recurrent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
                 hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch
                 hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch
 
-                yield obs_batch, privileged_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
+                yield obs_batch, privileged_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_distributions_parameters_batch, (
                     hid_a_batch,
                     hid_c_batch,
                 ), masks_batch, rnd_state_batch
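
A minimal standalone sketch of the refactored KL path, assuming only `torch` (it does not import rsl_rl): it rebuilds Normal distributions from parameters stacked as `[mean, std]` on a trailing dimension, which is the layout returned by the new `actions_distribution` property, and checks that `torch.distributions.kl.kl_divergence(...).sum(dim=-1)` reproduces the closed-form Gaussian KL removed from `ppo.py` (up to the 1e-5 guard inside the log). The batch sizes and the helper name `build_normal` are illustrative only, not part of the library.

import torch
from torch.distributions import Beta, Normal

batch, num_actions = 8, 6
old_mu, new_mu = torch.randn(batch, num_actions), torch.randn(batch, num_actions)
old_std, new_std = torch.rand(batch, num_actions) + 0.1, torch.rand(batch, num_actions) + 0.1

# New storage layout: parameters stacked on a trailing dimension, as in `actions_distribution`.
old_params = torch.stack([old_mu, old_std], dim=-1)
new_params = torch.stack([new_mu, new_std], dim=-1)

def build_normal(params):
    # Mirrors the Gaussian `build_distribution`: [..., 0] is the mean, [..., 1] the std.
    return Normal(params[..., 0], params[..., 1])

# KL(old || new), summed over action dimensions as in the original closed-form expression.
kl_new = torch.distributions.kl.kl_divergence(build_normal(old_params), build_normal(new_params)).sum(dim=-1)

# Closed-form diagonal-Gaussian KL that the patch removes from ppo.py (without the epsilon guard).
kl_ref = torch.sum(
    torch.log(new_std / old_std)
    + (old_std.square() + (old_mu - new_mu).square()) / (2.0 * new_std.square())
    - 0.5,
    dim=-1,
)
assert torch.allclose(kl_new, kl_ref, atol=1e-5)

# The same call also works for the Beta parameterization, since the stored parameters are
# first rebuilt into a torch.distributions.Beta; this is why PPO no longer needs mu/sigma.
a0, b0 = torch.rand(batch, num_actions) + 1.0, torch.rand(batch, num_actions) + 1.0
a1, b1 = torch.rand(batch, num_actions) + 1.0, torch.rand(batch, num_actions) + 1.0
kl_beta = torch.distributions.kl.kl_divergence(Beta(a0, b0), Beta(a1, b1)).sum(dim=-1)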