diff --git a/sb3_contrib/crossq/crossq.py b/sb3_contrib/crossq/crossq.py
index 1b7f90b8..4798b84d 100644
--- a/sb3_contrib/crossq/crossq.py
+++ b/sb3_contrib/crossq/crossq.py
@@ -10,7 +10,7 @@
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from torch.nn import functional as F
 
-from sb3_contrib.crossq.policies import Actor, CrossQCritic, CrossQPolicy, MlpPolicy
+from sb3_contrib.crossq.policies import Actor, CrossQCritic, CrossQPolicy, MlpPolicy, MultiInputPolicy
 
 SelfCrossQ = TypeVar("SelfCrossQ", bound="CrossQ")
 
@@ -67,7 +67,8 @@ class CrossQ(OffPolicyAlgorithm):
 
     policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": MlpPolicy,
-        # TODO: Implement CnnPolicy and MultiInputPolicy
+        "MultiInputPolicy": MultiInputPolicy,
+        # TODO: Implement CnnPolicy
     }
     policy: CrossQPolicy
     actor: Actor
@@ -235,7 +236,14 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         #
         # 2. From a computational perspective a single forward pass is simply more efficient than
         #    two sequential forward passes.
-        all_obs = th.cat([replay_data.observations, replay_data.next_observations], dim=0)
+        if isinstance(replay_data.observations, dict):
+            # Dict observation space: concatenate obs and next_obs key by key
+            all_obs = {
+                key: th.cat([replay_data.observations[key], replay_data.next_observations[key]], dim=0)
+                for key in replay_data.observations
+            }
+        else:
+            all_obs = th.cat([replay_data.observations, replay_data.next_observations], dim=0)
         all_actions = th.cat([replay_data.actions, next_actions], dim=0)
         # Update critic BN stats
         self.critic.set_bn_training_mode(True)
diff --git a/sb3_contrib/crossq/policies.py b/sb3_contrib/crossq/policies.py
index eedaa365..f526c92a 100644
--- a/sb3_contrib/crossq/policies.py
+++ b/sb3_contrib/crossq/policies.py
@@ -9,6 +9,7 @@
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
+    CombinedExtractor,
     FlattenExtractor,
     create_mlp,
     get_actor_critic_arch,
 )
@@ -529,3 +530,82 @@ def set_training_mode(self, mode: bool) -> None:
 
 
 MlpPolicy = CrossQPolicy
+
+
+class MultiInputPolicy(CrossQPolicy):
+    """
+    Policy class (with both actor and critic) for CrossQ, to be used with Dict observation spaces.
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param lr_schedule: Learning rate schedule (could be constant)
+    :param net_arch: The specification of the policy and value networks.
+    :param activation_fn: Activation function
+    :param batch_norm: Whether to use Batch Normalization in the networks
+    :param batch_norm_momentum: Momentum parameter of the BatchNorm layers
+    :param batch_norm_eps: Epsilon parameter of the BatchNorm layers
+    :param renorm_warmup_steps: Number of steps used to warm up the BatchRenorm layers
+    :param use_sde: Whether to use State Dependent Exploration or not
+    :param log_std_init: Initial value for the log standard deviation
+    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
+        a positive standard deviation (cf paper). It allows keeping the variance
+        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
+    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
+    :param features_extractor_class: Features extractor to use.
+    :param features_extractor_kwargs: Keyword arguments to pass to the features extractor.
+    :param normalize_images: Whether to normalize images or not,
+        dividing by 255.0 (True by default)
+    :param optimizer_class: The optimizer to use,
+        ``th.optim.Adam`` by default
+    :param optimizer_kwargs: Additional keyword arguments,
+        excluding the learning rate, to pass to the optimizer
+    :param n_critics: Number of critic networks to create.
+    :param share_features_extractor: Whether to share the features extractor
+        between the actor and the critic (sharing saves computation time)
+    """
+
+    def __init__(
+        self,
+        observation_space: spaces.Space,
+        action_space: spaces.Box,
+        lr_schedule: Schedule,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        batch_norm: bool = True,
+        batch_norm_momentum: float = 0.01,  # Note: Jax implementation is 1 - momentum = 0.99
+        batch_norm_eps: float = 0.001,
+        renorm_warmup_steps: int = 100_000,
+        use_sde: bool = False,
+        log_std_init: float = -3,
+        use_expln: bool = False,
+        clip_mean: float = 2.0,
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
+        normalize_images: bool = True,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
+        n_critics: int = 2,
+        share_features_extractor: bool = False,
+    ):
+        super().__init__(
+            observation_space,
+            action_space,
+            lr_schedule,
+            net_arch,
+            activation_fn,
+            batch_norm,
+            batch_norm_momentum,
+            batch_norm_eps,
+            renorm_warmup_steps,
+            use_sde,
+            log_std_init,
+            use_expln,
+            clip_mean,
+            features_extractor_class,
+            features_extractor_kwargs,
+            normalize_images,
+            optimizer_class,
+            optimizer_kwargs,
+            n_critics,
+            share_features_extractor,
+        )
diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py
index 429255a1..7ed483bc 100644
--- a/tests/test_dict_env.py
+++ b/tests/test_dict_env.py
@@ -9,7 +9,7 @@
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
 
-from sb3_contrib import QRDQN, TQC, TRPO
+from sb3_contrib import CrossQ, QRDQN, TQC, TRPO
 
 
 class DummyDictEnv(gym.Env):
@@ -89,7 +89,7 @@ def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only):
     check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
 
 
-@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
+@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, CrossQ])
 def test_consistency(model_class):
     """
     Make sure that dict obs with vector only vs using flatten obs is equivalent.
@@ -134,7 +134,7 @@ def test_consistency(model_class):
     assert np.allclose(action_1, action_2)
 
 
-@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
+@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, CrossQ])
 @pytest.mark.parametrize("channel_last", [False, True])
 def test_dict_spaces(model_class, channel_last):
     """
@@ -179,7 +179,7 @@ def test_dict_spaces(model_class, channel_last):
     evaluate_policy(model, env, n_eval_episodes=5, warn=False)
 
 
-@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
+@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, CrossQ])
 @pytest.mark.parametrize("channel_last", [False, True])
 def test_dict_vec_framestack(model_class, channel_last):
     """
@@ -228,7 +228,7 @@ def test_dict_vec_framestack(model_class, channel_last):
     evaluate_policy(model, env, n_eval_episodes=5, warn=False)
 
 
-@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
+@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, CrossQ])
 def test_vec_normalize(model_class):
     """
     Additional tests to check observation space support
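
For reference, a minimal smoke-test sketch of the new alias. ToyDictEnv is an illustrative stand-in for any Gymnasium environment with a spaces.Dict observation space and a continuous spaces.Box action space (CrossQ supports only continuous actions); the buffer and step counts are arbitrary.

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from sb3_contrib import CrossQ


class ToyDictEnv(gym.Env):
    """Illustrative stand-in: Dict observations, continuous (Box) actions."""

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Dict(
            {
                "vec": spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32),
                "pos": spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
            }
        )
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        # Constant reward and no termination: enough for a training smoke test.
        return self.observation_space.sample(), 0.0, False, False, {}


# "MultiInputPolicy" resolves through the policy_aliases entry added above.
model = CrossQ("MultiInputPolicy", ToyDictEnv(), buffer_size=1_000, learning_starts=100)
model.learn(total_timesteps=300)

The new parametrized test cases can be targeted directly with: pytest tests/test_dict_env.py -k CrossQ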