130 changes: 130 additions & 0 deletions sb3_contrib/common/hybrid/distributions.py
@@ -0,0 +1,130 @@
from typing import Optional, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from torch import nn

from stable_baselines3.common.distributions import Distribution


SelfHybridDistribution = TypeVar("SelfHybridDistribution", bound="HybridDistribution")


class HybridDistributionNet(nn.Module):
"""
Base class for hybrid distributions that handle both discrete and continuous actions.
This class should be extended to implement specific hybrid distributions.
"""

def __init__(self, latent_dim: int, categorical_dimensions: np.ndarray, n_continuous: int):
super().__init__()
# For discrete action space
self.categorical_nets = nn.ModuleList([nn.Linear(latent_dim, out_dim) for out_dim in categorical_dimensions])
# For continuous action space
self.gaussian_net = nn.Linear(latent_dim, n_continuous)

def forward(self, latent: th.Tensor) -> tuple[list[th.Tensor], th.Tensor]:
"""
Forward pass through all categorical nets and the gaussian net.

:param latent: Latent tensor input
:return: Tuple (list of categorical outputs, gaussian output)
"""
categorical_outputs = [net(latent) for net in self.categorical_nets]
gaussian_output = self.gaussian_net(latent)
return categorical_outputs, gaussian_output
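
# A minimal usage sketch (hypothetical shapes, for illustration only; not part of this module's API):
# with latent_dim=64, two discrete sub-actions of 3 and 5 choices, and 2 continuous dimensions,
#
#   net = HybridDistributionNet(64, np.array([3, 5]), 2)
#   logits, means = net(th.zeros(8, 64))  # batch of 8 latent vectors
#   # logits is [Tensor of shape (8, 3), Tensor of shape (8, 5)]; means has shape (8, 2)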


class HybridDistribution(Distribution):
def __init__(self, categorical_dimensions: np.ndarray, n_continuous: int):
super().__init__()
self.categorical_dimensions = categorical_dimensions
self.n_continuous = n_continuous
self.categorical_dists = None
self.gaussian_dist = None

def proba_distribution_net(self, latent_dim: int) -> Union[nn.Module, tuple[nn.Module, nn.Parameter]]:
"""Create the layers and parameters that represent the distribution.

Subclasses must define this, but the arguments and return type vary between
concrete classes."""
        action_net = HybridDistributionNet(latent_dim, self.categorical_dimensions, self.n_continuous)
return action_net

def proba_distribution(self: SelfHybridDistribution, *args, **kwargs) -> SelfHybridDistribution:
"""Set parameters of the distribution.

:return: self
"""

def log_prob(self, x: th.Tensor) -> th.Tensor:
"""
Returns the log likelihood

:param x: the taken action
:return: The log likelihood of the distribution
"""

def entropy(self) -> Optional[th.Tensor]:
"""
Returns Shannon's entropy of the probability

:return: the entropy, or None if no analytical form is known
"""

def sample(self) -> th.Tensor:
"""
Returns a sample from the probability distribution

:return: the stochastic action
"""

def mode(self) -> th.Tensor:
"""
Returns the most likely action (deterministic output)
from the probability distribution

        :return: the deterministic action
"""

    # TODO: this is not abstract in the superclass, so re-implementing it may be unnecessary --> check
def get_actions(self, deterministic: bool = False) -> th.Tensor:
"""
Return actions according to the probability distribution.

        :param deterministic: if True, return the mode of the distribution instead of sampling
        :return: the selected actions
"""
if deterministic:
return self.mode()
return self.sample()

def actions_from_params(self, *args, **kwargs) -> th.Tensor:
"""
Returns samples from the probability distribution
given its parameters.

:return: actions
"""

def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th.Tensor]:
"""
Returns samples and the associated log probabilities
from the probability distribution given its parameters.

:return: actions and log prob
"""


def make_hybrid_proba_distribution(action_space: spaces.Tuple) -> HybridDistribution:
"""
Create a hybrid probability distribution for the given action space.

:param action_space: Tuple Action space containing a MultiDiscrete action space and a Box action space.
:return: A HybridDistribution object that handles the hybrid action space.
"""
    assert len(action_space[1].shape) == 1, "Continuous action space must have a one-dimensional shape (e.g., (n,))"
    return HybridDistribution(
        categorical_dimensions=action_space[0].nvec,
        n_continuous=action_space[1].shape[0],
    )
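
# A hypothetical construction, for illustration only (the action space below is made up):
#
#   action_space = spaces.Tuple((spaces.MultiDiscrete([3, 2]), spaces.Box(low=-1.0, high=1.0, shape=(2,))))
#   dist = make_hybrid_proba_distribution(action_space)
#   # dist.categorical_dimensions == array([3, 2]) and dist.n_continuous == 2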

129 changes: 129 additions & 0 deletions sb3_contrib/common/hybrid/policies.py
@@ -0,0 +1,129 @@
import warnings
from typing import Any, Optional, Union

import torch as th
from gymnasium import spaces
from torch import nn

from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    MlpExtractor,
    NatureCNN,
)
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule

from sb3_contrib.common.hybrid.distributions import make_hybrid_proba_distribution


class HybridActorCriticPolicy(BasePolicy):
"""
Policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.

:param observation_space: Observation space
:param action_space: Tuple Action space containing a MultiDiscrete action space and a Box action space.
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param log_std_init: Initial value for the log standard deviation
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

def __init__(
self,
observation_space: spaces.Space,
        action_space: spaces.Tuple,  # expected to contain (MultiDiscrete, Box), checked below
lr_schedule: Schedule,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
log_std_init: float = 0.0,
features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer
if optimizer_class == th.optim.Adam:
optimizer_kwargs["eps"] = 1e-5

super().__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
normalize_images=normalize_images,
squash_output=False,
)

# assert that the action space is compatible with its type hint
assert isinstance(action_space, spaces.Tuple), "Action space must be a gymnasium.spaces.Tuple"
assert len(action_space.spaces) == 2, "Action space Tuple must contain exactly two spaces"
assert isinstance(action_space.spaces[0], spaces.MultiDiscrete), "First element of action space Tuple must be MultiDiscrete"
assert isinstance(action_space.spaces[1], spaces.Box), "Second element of action space Tuple must be Box"

if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
warnings.warn(
(
"As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
"you should now pass directly a dictionary and not a list "
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
),
)
net_arch = net_arch[0]

# Default network architecture, from stable-baselines
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = dict(pi=[64, 64], vf=[64, 64])

self.net_arch = net_arch
self.activation_fn = activation_fn
self.ortho_init = ortho_init

# features extractor
self.share_features_extractor = share_features_extractor
self.features_extractor = self.make_features_extractor()
self.features_dim = self.features_extractor.features_dim
if self.share_features_extractor:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.features_extractor
else:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()

self.log_std_init = log_std_init

# Action distribution
self.action_dist = make_hybrid_proba_distribution(action_space)

        # TODO: build the policy/value networks and the optimizer via self._build(lr_schedule)


# TODO: check superclass
class HybridActorCriticCnnPolicy(HybridActorCriticPolicy):
pass


# TODO: check superclass
class HybridMultiInputActorCriticPolicy(HybridActorCriticPolicy):
pass
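
# A minimal construction sketch (hypothetical spaces, for illustration only; the draft policy
# is not fully functional until _build and the hybrid forward/evaluate methods are implemented):
#
#   obs_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))
#   act_space = spaces.Tuple((spaces.MultiDiscrete([3, 2]), spaces.Box(low=-1.0, high=1.0, shape=(2,))))
#   policy = HybridActorCriticPolicy(obs_space, act_space, lr_schedule=lambda _: 3e-4)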
4 changes: 4 additions & 0 deletions sb3_contrib/ppo_hybrid/__init__.py
@@ -0,0 +1,4 @@
from sb3_contrib.ppo_hybrid.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.ppo_hybrid.ppo_hybrid import HybridPPO

__all__ = ["CnnPolicy", "HybridPPO", "MlpPolicy", "MultiInputPolicy"]
9 changes: 9 additions & 0 deletions sb3_contrib/ppo_hybrid/policies.py
@@ -0,0 +1,9 @@
from sb3_contrib.common.hybrid.policies import (
HybridActorCriticPolicy,
HybridActorCriticCnnPolicy,
HybridMultiInputActorCriticPolicy,
)

MlpPolicy = HybridActorCriticPolicy
CnnPolicy = HybridActorCriticCnnPolicy
MultiInputPolicy = HybridMultiInputActorCriticPolicy
69 changes: 69 additions & 0 deletions sb3_contrib/ppo_hybrid/ppo_hybrid.py
@@ -0,0 +1,69 @@
from typing import ClassVar

from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.ppo import PPO

from sb3_contrib.common.hybrid.policies import (
    HybridActorCriticCnnPolicy,
    HybridActorCriticPolicy,
    HybridMultiInputActorCriticPolicy,
)


class HybridPPO(PPO):
    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": HybridActorCriticPolicy,
        "CnnPolicy": HybridActorCriticCnnPolicy,
        "MultiInputPolicy": HybridMultiInputActorCriticPolicy,
    }

def __init__(
self,
policy,
env,
learning_rate=0.0003,
n_steps=2048,
batch_size=64,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
clip_range=0.2,
clip_range_vf=None,
normalize_advantage=True,
ent_coef=0,
vf_coef=0.5,
max_grad_norm=0.5,
use_sde=False,
sde_sample_freq=-1,
rollout_buffer_class=None,
rollout_buffer_kwargs=None,
target_kl=None,
stats_window_size=100,
tensorboard_log=None,
policy_kwargs=None,
verbose=0,
seed=None,
device="auto",
_init_setup_model=True,
):
super().__init__(
policy,
env,
learning_rate,
n_steps,
batch_size,
n_epochs,
gamma,
gae_lambda,
clip_range,
clip_range_vf,
normalize_advantage,
ent_coef,
vf_coef,
max_grad_norm,
use_sde,
sde_sample_freq,
rollout_buffer_class,
rollout_buffer_kwargs,
target_kl,
stats_window_size,
tensorboard_log,
policy_kwargs,
verbose,
seed,
device,
_init_setup_model,
)
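
# Hypothetical usage, assuming an environment whose action space is a Tuple(MultiDiscrete, Box)
# (not runnable until the hybrid policy and rollout handling in this draft are completed):
#
#   model = HybridPPO("MlpPolicy", env, verbose=1)
#   model.learn(total_timesteps=10_000)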