diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index cd4851d0..a98488e3 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -30,6 +30,20 @@ Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment. model.learn(total_timesteps=10_000, log_interval=4) model.save("qrdqn_cartpole") +IQN +--- + +Train an Implicit Quantile Network (IQN) agent on the CartPole environment. + +.. code-block:: python + + from sb3_contrib import IQN + + policy_kwargs = dict(n_quantiles=32) + model = IQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) + model.learn(total_timesteps=10_000, log_interval=4) + model.save("iqn_cartpole") + MaskablePPO ----------- diff --git a/docs/index.rst b/docs/index.rst index 5e322652..63415422 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d :caption: RL Algorithms modules/ars + modules/iqn modules/ppo_mask modules/ppo_recurrent modules/qrdqn diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 3fbd28d8..91aad7e5 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,6 +1,7 @@ import os from sb3_contrib.ars import ARS +from sb3_contrib.iqn import IQN from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.ppo_recurrent import RecurrentPPO from sb3_contrib.qrdqn import QRDQN @@ -14,6 +15,7 @@ __all__ = [ "ARS", + "IQN", "MaskablePPO", "RecurrentPPO", "QRDQN", diff --git a/sb3_contrib/iqn/__init__.py b/sb3_contrib/iqn/__init__.py new file mode 100644 index 00000000..06e76ea8 --- /dev/null +++ b/sb3_contrib/iqn/__init__.py @@ -0,0 +1,4 @@ +from sb3_contrib.iqn.iqn import IQN +from sb3_contrib.iqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["IQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/iqn/iqn.py b/sb3_contrib/iqn/iqn.py new file mode 100644 index 00000000..8e3187c3 --- /dev/null +++ b/sb3_contrib/iqn/iqn.py @@ -0,0 +1,282 @@ +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union + +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.preprocessing import maybe_transpose +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update + +from sb3_contrib.common.utils import quantile_huber_loss +from sb3_contrib.iqn.policies import CnnPolicy, IQNPolicy, MlpPolicy, MultiInputPolicy + +SelfIQN = TypeVar("SelfIQN", bound="IQN") + + +class IQN(OffPolicyAlgorithm): + """ + Implicit Quantile Network (IQN) + Paper: https://arxiv.org/abs/1806.06923 + Default hyperparameters are taken from the paper and are tuned for Atari games. + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
+ :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update + :param gamma: the discount factor + :param num_tau_samples: Number of samples used to estimate the current quantiles + :param num_tau_prime_samples: Number of samples used to estimate the next quantiles + :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. + :param gradient_steps: How many gradient steps to do after each rollout + (see ``train_freq`` and ``n_episodes_rollout``) + Set to ``-1`` means to do as many gradient steps as steps done in the environment + during the rollout. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + :param target_update_interval: update the target network every ``target_update_interval`` + environment steps. + :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced + :param exploration_initial_eps: initial value of random action probability + :param exploration_final_eps: final value of random action probability + :param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping) + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. 
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + + def __init__( + self, + policy: Union[str, Type[IQNPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 5e-5, + buffer_size: int = 1000000, # 1e6 + learning_starts: int = 50000, + batch_size: int = 32, + tau: float = 1.0, + gamma: float = 0.99, + num_tau_samples: int = 32, + num_tau_prime_samples: int = 64, + train_freq: int = 4, + gradient_steps: int = 1, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + optimize_memory_usage: bool = False, + target_update_interval: int = 10000, + exploration_fraction: float = 0.005, + exploration_initial_eps: float = 1.0, + exploration_final_eps: float = 0.01, + max_grad_norm: Optional[float] = None, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super().__init__( + policy, + env, + learning_rate, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + action_noise=None, # No action noise + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + seed=seed, + sde_support=False, + optimize_memory_usage=optimize_memory_usage, + supported_action_spaces=(spaces.Discrete,), + support_multi_env=True, + ) + + self.num_tau_samples = num_tau_samples + self.num_tau_prime_samples = num_tau_prime_samples + + self.exploration_initial_eps = exploration_initial_eps + self.exploration_final_eps = exploration_final_eps + self.exploration_fraction = exploration_fraction + self.target_update_interval = target_update_interval + self.max_grad_norm = max_grad_norm + # "epsilon" for the epsilon-greedy exploration + self.exploration_rate = 0.0 + # Linear schedule will be defined in `_setup_model()` + self.exploration_schedule: Schedule + self.policy: IQNPolicy # type: ignore[assignment] + + if "optimizer_class" not in self.policy_kwargs: + self.policy_kwargs["optimizer_class"] = th.optim.Adam + # Proposed in the QR-DQN paper where `batch_size = 32` + self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size) + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + super()._setup_model() + self._create_aliases() + # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996 + self.batch_norm_stats = get_parameters_by_name(self.quantile_net, ["running_"]) + self.batch_norm_stats_target = get_parameters_by_name(self.quantile_net_target, ["running_"]) + self.exploration_schedule = get_linear_fn( + self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction + ) + + def _create_aliases(self) -> None: + self.quantile_net = self.policy.quantile_net + self.quantile_net_target = self.policy.quantile_net_target + self.n_quantiles = self.policy.n_quantiles + + def _on_step(self) -> None: + """ + Update the exploration rate and target network if needed. + This method is called in ``collect_rollouts()`` after each step in the environment. 
+ """ + if self.num_timesteps % self.target_update_interval == 0: + polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau) + # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996 + polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) + + self.exploration_rate = self.exploration_schedule(self._current_progress_remaining) + self.logger.record("rollout/exploration_rate", self.exploration_rate) + + def train(self, gradient_steps: int, batch_size: int = 100) -> None: + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update learning rate according to schedule + self._update_learning_rate(self.policy.optimizer) + + losses = [] + for _ in range(gradient_steps): + # Sample replay buffer + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + + with th.no_grad(): + # Compute the quantiles of next observation + next_quantiles = self.quantile_net_target(replay_data.next_observations, self.n_quantiles) + # Compute the greedy actions which maximize the next Q values + next_greedy_actions = next_quantiles.mean(dim=1, keepdim=True).argmax(dim=2, keepdim=True) + # Make "num_tau_prime_samples" copies of actions, and reshape to (batch_size, num_tau_prime_samples, 1) + next_greedy_actions = next_greedy_actions.expand(batch_size, self.num_tau_prime_samples, 1) + # Compute the quantiles of next observation, but with another number of tau samples + next_quantiles = self.quantile_net_target(replay_data.next_observations, self.num_tau_prime_samples) + # Follow greedy policy: use the one with the highest Q values + next_quantiles = next_quantiles.gather(dim=2, index=next_greedy_actions).squeeze(dim=2) + # 1-step TD target + target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles + + # Get current quantile estimates + current_quantiles = self.quantile_net(replay_data.observations, self.num_tau_samples) + # Make "num_tau_samples" copies of actions, and reshape to (batch_size, num_tau_samples, 1). + actions = replay_data.actions[..., None].long().expand(batch_size, self.num_tau_samples, 1) + # Retrieve the quantiles for the actions from the replay buffer + current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2) + + # Compute Quantile Huber loss, summing over a quantile dimension as in the paper. + loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True) + losses.append(loss.item()) + + # Optimize the policy + self.policy.optimizer.zero_grad() + loss.backward() + # Clip gradient norm + if self.max_grad_norm is not None: + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + # Increase update counter + self._n_updates += gradient_steps + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/loss", np.mean(losses)) + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). 
+ + :param observation: the input observation + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this corresponds to the beginning of episodes, + where the hidden states of the RNN must be reset. + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ + if not deterministic and np.random.rand() < self.exploration_rate: + if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space): + if isinstance(self.observation_space, spaces.Dict): + n_batch = observation[list(observation.keys())[0]].shape[0] + else: + n_batch = observation.shape[0] + action = np.array([self.action_space.sample() for _ in range(n_batch)]) + else: + action = np.array(self.action_space.sample()) + else: + action, state = self.policy.predict(observation, state, episode_start, deterministic) + return action, state + + def learn( + self: SelfIQN, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + tb_log_name: str = "IQN", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfIQN: + + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + tb_log_name=tb_log_name, + reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, + ) + + def _excluded_save_params(self) -> List[str]: + return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"] + + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + state_dicts = ["policy", "policy.optimizer"] + + return state_dicts, [] diff --git a/sb3_contrib/iqn/policies.py b/sb3_contrib/iqn/policies.py new file mode 100644 index 00000000..91da8a2c --- /dev/null +++ b/sb3_contrib/iqn/policies.py @@ -0,0 +1,380 @@ +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch as th +from gym import spaces +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + NatureCNN, + create_mlp, +) +from stable_baselines3.common.type_aliases import Schedule +from torch import nn + + +def unflatten(input: th.Tensor, dim: int, sizes: Tuple[int, ...]) -> th.Tensor: + # The purpose of this function is to remember that once the minimum + # version of torch is at least 1.13, we can simply use th.unflatten. + if th.__version__ >= "1.13": + return th.unflatten(input, dim, sizes) + else: + return nn.Unflatten(dim, (sizes))(input) + + +class CosineEmbeddingNetwork(nn.Module): + """ + Computes the embeddings of tau values using cosine functions. + + Takes a tensor of shape (batch_size, num_tau_samples) representing the tau values, and returns + a tensor of shape (batch_size, num_tau_samples, features_dim) representing the embeddings of tau values.
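+    Following the IQN paper (https://arxiv.org/abs/1806.06923), each tau value is expanded into cosine features ``cos(i * pi * tau)`` for ``i = 1, ..., num_cosine`` and then passed through a linear layer with a ReLU activation.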
+ + :param num_cosine: Number of cosines basis functions + :param features_dim: Dimension of the embedding + """ + + def __init__(self, num_cosine: int, features_dim: int) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Linear(num_cosine, features_dim), + nn.ReLU(), + ) + self.num_cosine = num_cosine + + def forward(self, taus: th.Tensor) -> th.Tensor: + # Compute cos(i * pi * tau) + i_pi = th.pi * th.arange(start=1, end=self.num_cosine + 1, device=taus.device) + i_pi = i_pi.reshape(1, 1, self.num_cosine) # (1, 1, num_cosines) + taus = th.unsqueeze(taus, dim=-1) # (batch_size, num_tau_samples, 1) + cosines = th.cos(taus * i_pi) # (batch_size, num_tau_samples, num_cosines) + + # Compute embeddings of taus + cosines = th.flatten(cosines, end_dim=1) # (batch_size * num_tau_samples, num_cosines) + tau_embeddings = self.net(cosines) # (batch_size * num_tau_samples, features_dim) + return unflatten(tau_embeddings, dim=0, sizes=(-1, taus.shape[1])) # (batch_size, num_tau_samples, features_dim) + + +class QuantileNetwork(BasePolicy): + """ + Quantile network for IQN + + :param observation_space: Observation space + :param action_space: Action space + :param features_extractor: + :param features_dim: + :param n_quantiles: Number of quantiles + :param num_cosine: Number of cosines basis functions + :param net_arch: The specification of the network architecture. + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + features_extractor: BaseFeaturesExtractor, + features_dim: int, + n_quantiles: int = 64, + num_cosine: int = 64, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + ): + super().__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + ) + + if net_arch is None: + net_arch = [64, 64] + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.features_extractor = features_extractor + self.features_dim = features_dim + self.n_quantiles = n_quantiles + self.num_cosine = num_cosine + action_dim = self.action_space.n # number of actions + quantile_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn) + self.quantile_net = nn.Sequential(*quantile_net) + self.cosine_net = CosineEmbeddingNetwork(self.num_cosine, self.features_dim) + + def forward(self, obs: th.Tensor, num_tau_samples: int) -> th.Tensor: + """ + Predict the quantiles. + + :param obs: Observation + :return: The estimated quantiles for each action. 
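+            The returned tensor has shape ``(batch_size, num_tau_samples, n_actions)``, with tau values drawn uniformly at random in [0, 1).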
+ """ + features = self.extract_features(obs, self.features_extractor) + taus = th.rand(features.shape[0], num_tau_samples, device=self.device) + tau_embeddings = self.cosine_net(taus) + # Compute the embeddings and taus + features = th.unsqueeze(features, dim=1) # (batch_size, 1, features_dim) + features = features * tau_embeddings # (batch_size, M, features_dim) + + # Compute the quantile values + features = th.flatten(features, end_dim=1) # (batch_size * M, features_dim) + quantiles = self.quantile_net(features) + return unflatten(quantiles, dim=0, sizes=(-1, tau_embeddings.shape[1])) # (batch_size, M, num_actions) + + def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: + q_values = self(observation, self.n_quantiles).mean(dim=1) + # Greedy action + action = q_values.argmax(dim=1).reshape(-1) + return action + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + n_quantiles=self.n_quantiles, + num_cosine=self.num_cosine, + activation_fn=self.activation_fn, + features_extractor=self.features_extractor, + cosine_net=self.cosine_net, + ) + ) + return data + + +class IQNPolicy(BasePolicy): + """ + Policy class with quantile and target networks for IQN. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param n_quantiles: Number of quantiles + :param num_cosine: Number of cosines basis functions + :param net_arch: The specification of the network architecture. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + n_quantiles: int = 64, + num_cosine: int = 64, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + normalize_images=normalize_images, + ) + + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = [64, 64] + + self.n_quantiles = n_quantiles + self.num_cosine = num_cosine + self.net_arch = net_arch + self.activation_fn = activation_fn + + self.net_args = { + "observation_space": self.observation_space, + "action_space": self.action_space, + "n_quantiles": self.n_quantiles, + "num_cosine": self.num_cosine, + "net_arch": self.net_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + } + + self.quantile_net: QuantileNetwork + self.quantile_net_target: QuantileNetwork + self._build(lr_schedule) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the network and the optimizer. 
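+        The target quantile network is created as an exact copy of the online quantile network and is kept in evaluation mode.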
+ + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self.quantile_net = self.make_quantile_net() + self.quantile_net_target = self.make_quantile_net() + self.quantile_net_target.load_state_dict(self.quantile_net.state_dict()) + self.quantile_net_target.set_training_mode(False) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class( + self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs + ) # type:ignore[call-arg] # Assume that all optimizers have lr as argument + + def make_quantile_net(self) -> QuantileNetwork: + # Make sure we always have separate networks for features extractors etc + net_args = self._update_features_extractor(self.net_args, features_extractor=None) + return QuantileNetwork(**net_args).to(self.device) + + def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: + return self._predict(obs, deterministic=deterministic) + + def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: + return self.quantile_net._predict(obs, deterministic=deterministic) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + n_quantiles=self.net_args["n_quantiles"], + num_cosine=self.net_args["num_cosine"], + net_arch=self.net_args["net_arch"], + activation_fn=self.net_args["activation_fn"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + This affects certain modules, such as batch normalisation and dropout. + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.quantile_net.set_training_mode(mode) + self.training = mode + + +MlpPolicy = IQNPolicy + + +class CnnPolicy(IQNPolicy): + """ + Policy class for IQN when using images as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param n_quantiles: Number of quantiles + :param num_cosine: Number of cosines basis functions + :param net_arch: The specification of the network architecture. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + n_quantiles: int = 64, + num_cosine: int = 64, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + n_quantiles, + num_cosine, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class MultiInputPolicy(IQNPolicy): + """ + Policy class for IQN when using dict observations as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param n_quantiles: Number of quantiles + :param num_cosine: Number of cosines basis functions + :param net_arch: The specification of the network architecture. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + n_quantiles: int = 64, + num_cosine: int = 64, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + n_quantiles, + num_cosine, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) diff --git a/setup.cfg b/setup.cfg index b4464589..02d44c72 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,8 @@ exclude = (?x)( | sb3_contrib/ars/ars.py$ | sb3_contrib/qrdqn/qrdqn.py$ | sb3_contrib/qrdqn/policies.py$ + | sb3_contrib/iqn/iqn.py$ + | sb3_contrib/iqn/policies.py$ | sb3_contrib/common/recurrent/policies.py$ | sb3_contrib/common/recurrent/buffers.py$ | sb3_contrib/common/maskable/distributions.py$ diff --git a/tests/test_cnn.py b/tests/test_cnn.py index be8758a8..63551396 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -9,11 +9,11 @@ from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, VecTransposeImage, is_vecenv_wrapped -from sb3_contrib import QRDQN, TQC, TRPO, MaskablePPO, RecurrentPPO +from sb3_contrib import IQN, QRDQN, TQC, TRPO, 
MaskablePPO, RecurrentPPO from sb3_contrib.common.wrappers import ActionMasker -@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO]) +@pytest.mark.parametrize("model_class", [TQC, IQN, QRDQN, TRPO]) @pytest.mark.parametrize("share_features_extractor", [True, False]) def test_cnn(tmp_path, model_class, share_features_extractor): SAVE_NAME = "cnn_model.zip" @@ -27,7 +27,7 @@ def test_cnn(tmp_path, model_class, share_features_extractor): discrete=model_class not in {TQC}, ) kwargs = dict(policy_kwargs=dict(share_features_extractor=share_features_extractor)) - if model_class in {TQC, QRDQN}: + if model_class in {TQC, IQN, QRDQN}: # share_features_extractor is checked later for offpolicy algorithms if share_features_extractor: return @@ -49,7 +49,7 @@ def test_cnn(tmp_path, model_class, share_features_extractor): assert is_vecenv_wrapped(model.get_env(), VecTransposeImage) # Test stochastic predict with channel last input - if model_class == QRDQN: + if model_class in {IQN, QRDQN}: model.exploration_rate = 0.9 for _ in range(10): @@ -68,9 +68,9 @@ def test_cnn(tmp_path, model_class, share_features_extractor): os.remove(str(tmp_path / SAVE_NAME)) -def patch_qrdqn_names_(model): - # Small hack to make the test work with QRDQN - if isinstance(model, QRDQN): +def patch_quantiles_dqn_names_(model): + # Small hack to make the test work with IQN and QRDQN + if isinstance(model, (QRDQN, IQN)): model.critic = model.quantile_net model.critic_target = model.quantile_net_target @@ -85,15 +85,15 @@ def params_should_differ(params, other_params): assert not th.allclose(param, other_param) -@pytest.mark.parametrize("model_class", [TQC, QRDQN]) +@pytest.mark.parametrize("model_class", [TQC, IQN, QRDQN]) @pytest.mark.parametrize("share_features_extractor", [True, False]) def test_feature_extractor_target_net(model_class, share_features_extractor): - if model_class == QRDQN and share_features_extractor: + if model_class in {IQN, QRDQN} and share_features_extractor: pytest.skip() env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC}) - if model_class in {TQC, QRDQN}: + if model_class in {TQC, IQN, QRDQN}: # Avoid memory error when using replay buffer # Reduce the size of the features and the number of quantiles kwargs = dict( @@ -101,19 +101,19 @@ def test_feature_extractor_target_net(model_class, share_features_extractor): learning_starts=100, policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)), ) - if model_class != QRDQN: + if model_class not in {IQN, QRDQN}: kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor model = model_class("CnnPolicy", env, seed=0, **kwargs) - patch_qrdqn_names_(model) + patch_quantiles_dqn_names_(model) if share_features_extractor: # Check that the objects are the same and not just copied assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor) else: # Check that the objects differ - if model_class != QRDQN: + if model_class not in {IQN, QRDQN}: assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor) # Critic and target should be equal at the begginning of training @@ -127,7 +127,7 @@ def test_feature_extractor_target_net(model_class, share_features_extractor): # Re-initialize and collect some random data (without doing gradient steps) model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10) - patch_qrdqn_names_(model) + patch_quantiles_dqn_names_(model) original_param = 
deepcopy(list(model.critic.parameters())) original_target_param = deepcopy(list(model.critic_target.parameters())) @@ -149,9 +149,9 @@ def test_feature_extractor_target_net(model_class, share_features_extractor): model.lr_schedule = lambda _: 0.0 # Re-activate polyak update model.tau = 0.01 - # Special case for QRDQN: target net is updated in the `collect_rollouts()` + # Special case for IQN and QRDQN: target net is updated in the `collect_rollouts()` # not the `train()` method - if model_class == QRDQN: + if model_class in {IQN, QRDQN}: model.target_update_interval = 1 model._on_step() @@ -164,7 +164,7 @@ def test_feature_extractor_target_net(model_class, share_features_extractor): params_should_match(original_param, model.critic.parameters()) -@pytest.mark.parametrize("model_class", [TRPO, MaskablePPO, RecurrentPPO, QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [TRPO, MaskablePPO, RecurrentPPO, IQN, QRDQN, TQC]) @pytest.mark.parametrize("normalize_images", [True, False]) def test_image_like_input(model_class, normalize_images): """ diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py index 458d3f06..b3a8bea7 100644 --- a/tests/test_deterministic.py +++ b/tests/test_deterministic.py @@ -3,7 +3,7 @@ from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.vec_env import VecNormalize -from sb3_contrib import ARS, QRDQN, TQC, RecurrentPPO +from sb3_contrib import ARS, IQN, QRDQN, TQC, RecurrentPPO from sb3_contrib.common.vec_env import AsyncEval N_STEPS_TRAINING = 500 @@ -11,7 +11,7 @@ ARS_MULTI = "ars_multi" -@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI, RecurrentPPO]) +@pytest.mark.parametrize("algo", [ARS, IQN, QRDQN, TQC, ARS_MULTI, RecurrentPPO]) def test_deterministic_training_common(algo): results = [[], []] rewards = [[], []] @@ -27,7 +27,7 @@ def test_deterministic_training_common(algo): if algo in [TQC]: kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4}) else: - if algo == QRDQN: + if algo in {IQN, QRDQN}: env_id = "CartPole-v1" kwargs.update({"learning_starts": 100, "target_update_interval": 100}) elif algo == ARS: diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index eada97d5..0a2e2aed 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -6,7 +6,7 @@ from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize -from sb3_contrib import QRDQN, TQC, TRPO +from sb3_contrib import IQN, QRDQN, TQC, TRPO class DummyDictEnv(gym.Env): @@ -78,13 +78,13 @@ def render(self, mode="human"): pass -@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) +@pytest.mark.parametrize("model_class", [IQN, QRDQN, TQC, TRPO]) def test_consistency(model_class): """ Make sure that dict obs with vector only vs using flatten obs is equivalent. This ensures notable that the network architectures are the same. 
""" - use_discrete_actions = model_class == QRDQN + use_discrete_actions = model_class in {IQN, QRDQN} dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) @@ -106,7 +106,7 @@ def test_consistency(model_class): train_freq=8, gradient_steps=1, ) - if model_class == QRDQN: + if model_class in {IQN, QRDQN}: kwargs["learning_starts"] = 0 dict_model = model_class("MultiInputPolicy", dict_env, gamma=0.5, seed=1, **kwargs) @@ -124,7 +124,7 @@ def test_consistency(model_class): assert np.allclose(action_1, action_2) -@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) +@pytest.mark.parametrize("model_class", [IQN, QRDQN, TQC, TRPO]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_spaces(model_class, channel_last): """ @@ -159,7 +159,7 @@ def test_dict_spaces(model_class, channel_last): train_freq=8, gradient_steps=1, ) - if model_class == QRDQN: + if model_class in {IQN, QRDQN}: kwargs["learning_starts"] = 0 model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) @@ -169,7 +169,7 @@ def test_dict_spaces(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) +@pytest.mark.parametrize("model_class", [IQN, QRDQN, TQC, TRPO]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_vec_framestack(model_class, channel_last): """ @@ -208,7 +208,7 @@ def test_dict_vec_framestack(model_class, channel_last): train_freq=8, gradient_steps=1, ) - if model_class == QRDQN: + if model_class in {IQN, QRDQN}: kwargs["learning_starts"] = 0 model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) @@ -218,13 +218,13 @@ def test_dict_vec_framestack(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) +@pytest.mark.parametrize("model_class", [IQN, QRDQN, TQC, TRPO]) def test_vec_normalize(model_class): """ Additional tests to check observation space support for GoalEnv and VecNormalize using MultiInputPolicy. 
""" - env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == QRDQN), 100)]) + env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class in {IQN, QRDQN}), 100)]) env = VecNormalize(env, norm_obs_keys=["vec"]) kwargs = {} @@ -248,7 +248,7 @@ def test_vec_normalize(model_class): train_freq=8, gradient_steps=1, ) - if model_class == QRDQN: + if model_class in {IQN, QRDQN}: kwargs["learning_starts"] = 0 model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) diff --git a/tests/test_identity.py b/tests/test_identity.py index 6ad03174..a69b7d40 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -4,18 +4,18 @@ from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv -from sb3_contrib import QRDQN, TRPO +from sb3_contrib import IQN, QRDQN, TRPO DIM = 4 -@pytest.mark.parametrize("model_class", [QRDQN, TRPO]) +@pytest.mark.parametrize("model_class", [IQN, QRDQN, TRPO]) @pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)]) def test_discrete(model_class, env): env_ = DummyVecEnv([lambda: env]) kwargs = {} n_steps = 1500 - if model_class == QRDQN: + if model_class in {IQN, QRDQN}: kwargs = dict( learning_starts=0, policy_kwargs=dict(n_quantiles=25, net_arch=[32]), diff --git a/tests/test_run.py b/tests/test_run.py index 6753ebb3..6b44aeb4 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -3,7 +3,7 @@ from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.vec_env import VecNormalize -from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO +from sb3_contrib import ARS, IQN, QRDQN, TQC, TRPO, MaskablePPO from sb3_contrib.common.envs import InvalidActionEnvDiscrete from sb3_contrib.common.vec_env import AsyncEval @@ -61,6 +61,19 @@ def test_qrdqn(): model.learn(total_timesteps=500) +def test_iqn(): + model = IQN( + "MlpPolicy", + "CartPole-v1", + policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]), + learning_starts=100, + buffer_size=500, + learning_rate=3e-4, + verbose=1, + ) + model.learn(total_timesteps=500) + + @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_trpo(env_id): model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1) @@ -115,7 +128,7 @@ def test_ars_n_top(n_top): model.learn(total_timesteps=500) -@pytest.mark.parametrize("model_class", [TQC, QRDQN]) +@pytest.mark.parametrize("model_class", [TQC, IQN, QRDQN]) def test_offpolicy_multi_env(model_class): if model_class in [TQC]: env_id = "Pendulum-v1" diff --git a/tests/test_save_load.py b/tests/test_save_load.py index b2a62d27..099c4573 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -12,16 +12,16 @@ from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv -from sb3_contrib import ARS, QRDQN, TQC, TRPO +from sb3_contrib import ARS, IQN, QRDQN, TQC, TRPO -MODEL_LIST = [ARS, QRDQN, TQC, TRPO] +MODEL_LIST = [ARS, IQN, QRDQN, TQC, TRPO] def select_env(model_class: BaseAlgorithm) -> gym.Env: """ - Selects an environment with the correct action space as QRDQN only supports discrete action space + Selects an environment with the correct action space as IQN and QRDQN only support discrete action space """ - if model_class == QRDQN: + if model_class in {IQN, QRDQN}: return IdentityEnv(10) else: return IdentityEnvBox(10) @@ 
-42,7 +42,7 @@ def test_save_load(tmp_path, model_class): policy_kwargs = dict(net_arch=[16]) - if model_class in {QRDQN, TQC}: + if model_class in {IQN, QRDQN, TQC}: policy_kwargs.update(dict(n_quantiles=20)) # create model @@ -171,13 +171,13 @@ def test_set_env(model_class): :param model_class: (BaseAlgorithm) A RL model """ - # use discrete for QRDQN + # use discrete for IQN and QRDQN env = DummyVecEnv([lambda: select_env(model_class)]) env2 = DummyVecEnv([lambda: select_env(model_class)]) env3 = select_env(model_class) kwargs = dict(policy_kwargs=dict(net_arch=[16])) - if model_class in {TQC, QRDQN}: + if model_class in {TQC, IQN, QRDQN}: kwargs.update(dict(learning_starts=100)) kwargs["policy_kwargs"].update(dict(n_quantiles=20)) @@ -273,7 +273,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str): if policy_str == "MlpPolicy": env = select_env(model_class) else: - if model_class in [TQC, QRDQN]: + if model_class in [TQC, IQN, QRDQN]: # Avoid memory error when using replay buffer # Reduce the size of the features kwargs = dict( @@ -286,10 +286,10 @@ def test_save_load_policy(tmp_path, model_class, policy_str): n_steps=128, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), ) - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class in {IQN, QRDQN}) # Reduce number of quantiles for faster tests - if model_class in [TQC, QRDQN]: + if model_class in [TQC, IQN, QRDQN]: kwargs["policy_kwargs"].update(dict(n_quantiles=20)) env = DummyVecEnv([lambda: env]) @@ -366,7 +366,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str): os.remove(tmp_path / "actor.pkl") -@pytest.mark.parametrize("model_class", [QRDQN]) +@pytest.mark.parametrize("model_class", [IQN, QRDQN]) @pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) def test_save_load_q_net(tmp_path, model_class, policy_str): """ @@ -379,7 +379,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str): if policy_str == "MlpPolicy": env = select_env(model_class) else: - if model_class in [QRDQN]: + if model_class in [IQN, QRDQN]: # Avoid memory error when using replay buffer # Reduce the size of the features kwargs = dict( @@ -387,10 +387,10 @@ def test_save_load_q_net(tmp_path, model_class, policy_str): learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), ) - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class in [IQN, QRDQN]) # Reduce number of quantiles for faster tests - if model_class in [QRDQN]: + if model_class in [IQN, QRDQN]: kwargs["policy_kwargs"].update(dict(n_quantiles=20)) env = DummyVecEnv([lambda: env]) diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index 1973d456..5a6c50fd 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -8,7 +8,7 @@ from stable_baselines3.common.preprocessing import get_flattened_obs_dim from stable_baselines3.common.torch_layers import BaseFeaturesExtractor -from sb3_contrib import QRDQN, TQC, MaskablePPO +from sb3_contrib import IQN, QRDQN, TQC, MaskablePPO from sb3_contrib.common.envs import InvalidActionEnvDiscrete from sb3_contrib.common.maskable.utils import get_action_masks @@ -45,7 +45,7 @@ def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> (th.Tensor, th.Tensor) return 
batch_norm.bias.clone(), batch_norm.running_mean.clone() -def clone_qrdqn_batch_norm_stats(model: QRDQN) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor): +def clone_iqn_qrdqn_batch_norm_stats(model: Union[IQN, QRDQN]) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor): """ Clone the bias and running mean from the quantile network and quantile-target network. :param model: @@ -85,7 +85,8 @@ def clone_on_policy_batch_norm(model: Union[MaskablePPO]) -> (th.Tensor, th.Tens CLONE_HELPERS = { - QRDQN: clone_qrdqn_batch_norm_stats, + QRDQN: clone_iqn_qrdqn_batch_norm_stats, + IQN: clone_iqn_qrdqn_batch_norm_stats, TQC: clone_tqc_batch_norm_stats, MaskablePPO: clone_on_policy_batch_norm, } @@ -125,8 +126,9 @@ def test_ppo_mask_train_eval_mode(): assert th.isclose(param_before, param_after).all() -def test_qrdqn_train_with_batch_norm(): - model = QRDQN( +@pytest.mark.parametrize("model_class", [IQN, QRDQN]) +def test_train_with_batch_norm(model_class): + model = model_class( "MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), @@ -140,7 +142,7 @@ def test_qrdqn_train_with_batch_norm(): quantile_net_running_mean_before, quantile_net_target_bias_before, quantile_net_target_running_mean_before, - ) = clone_qrdqn_batch_norm_stats(model) + ) = CLONE_HELPERS[model_class](model) model.learn(total_timesteps=200) # Force stats copy @@ -152,7 +154,7 @@ def test_qrdqn_train_with_batch_norm(): quantile_net_running_mean_after, quantile_net_target_bias_after, quantile_net_target_running_mean_after, - ) = clone_qrdqn_batch_norm_stats(model) + ) = CLONE_HELPERS[model_class](model) assert ~th.isclose(quantile_net_bias_before, quantile_net_bias_after).all() # Running stat should be copied even when tau=0 @@ -208,9 +210,9 @@ def test_tqc_train_with_batch_norm(): assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all() -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [IQN, QRDQN, TQC]) def test_offpolicy_collect_rollout_batch_norm(model_class): - if model_class in [QRDQN]: + if model_class in [IQN, QRDQN]: env_id = "CartPole-v1" else: env_id = "Pendulum-v1" @@ -239,19 +241,19 @@ def test_offpolicy_collect_rollout_batch_norm(model_class): assert th.isclose(param_before, param_after).all() -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [IQN, QRDQN, TQC]) @pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"]) def test_predict_with_dropout_batch_norm(model_class, env_id): if env_id == "CartPole-v1": if model_class in [TQC]: return - elif model_class in [QRDQN]: + elif model_class in [IQN, QRDQN]: return model_kwargs = dict(seed=1) clone_helper = CLONE_HELPERS[model_class] - if model_class in [QRDQN, TQC]: + if model_class in [IQN, QRDQN, TQC]: model_kwargs["learning_starts"] = 0 else: model_kwargs["n_steps"] = 64