16 changes: 15 additions & 1 deletion sb3_contrib/tqc/policies.py
@@ -213,6 +213,8 @@ def __init__(
         n_quantiles: int = 25,
         n_critics: int = 2,
         share_features_extractor: bool = False,
+        dropout_rate: float = 0.0,
+        layer_norm: bool = False,
     ):
         super().__init__(
             observation_space,
@@ -230,7 +232,14 @@ def __init__(
         self.quantiles_total = n_quantiles * n_critics

         for i in range(n_critics):
-            qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
+            qf_net = create_mlp(
+                features_dim + action_dim,
+                n_quantiles,
+                net_arch,
+                activation_fn,
+                dropout_rate=dropout_rate,
+                layer_norm=layer_norm,
+            )
             qf_net = nn.Sequential(*qf_net)
             self.add_module(f"qf{i}", qf_net)
             self.q_networks.append(qf_net)
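Reviewer note on the hunk above: this relies on `create_mlp` in stable-baselines3 accepting `dropout_rate` and `layer_norm`, which is not in the released version. For readers without that branch, the sketch below shows the kind of critic network this is expected to build, following the DroQ arrangement of Linear -> Dropout -> LayerNorm -> activation per hidden layer; the exact layer ordering inside `create_mlp` is an assumption, and `droq_style_qf_net` is a hypothetical stand-in, not code from this PR.

import torch.nn as nn

def droq_style_qf_net(input_dim, n_quantiles, net_arch, dropout_rate=0.01, layer_norm=True, activation_fn=nn.ReLU):
    """Hypothetical stand-in for create_mlp(..., dropout_rate=..., layer_norm=...)."""
    layers = []
    last_dim = input_dim
    for hidden_dim in net_arch:
        layers.append(nn.Linear(last_dim, hidden_dim))
        if dropout_rate > 0.0:
            layers.append(nn.Dropout(p=dropout_rate))
        if layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(activation_fn())
        last_dim = hidden_dim
    # Output head: one value per quantile, no dropout/normalization here
    layers.append(nn.Linear(last_dim, n_quantiles))
    return nn.Sequential(*layers)

# Example: Pendulum-v1 with the default extractor (3 obs dims + 1 action dim)
qf_net = droq_style_qf_net(3 + 1, n_quantiles=25, net_arch=[256, 256])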
@@ -298,6 +307,9 @@ def __init__(
         n_quantiles: int = 25,
         n_critics: int = 2,
         share_features_extractor: bool = False,
+        # For the critic only
+        dropout_rate: float = 0.0,
+        layer_norm: bool = False,
     ):
         super().__init__(
             observation_space,
@@ -341,6 +353,8 @@ def __init__(
             "n_critics": n_critics,
             "net_arch": critic_arch,
             "share_features_extractor": share_features_extractor,
+            "dropout_rate": dropout_rate,
+            "layer_norm": layer_norm,
         }
         self.critic_kwargs.update(tqc_kwargs)
         self.actor, self.actor_target = None, None
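Reviewer note: with the two extra entries in `critic_kwargs`, the existing `make_critic()` path should forward them to `TQCCritic` unchanged, so end users only need `policy_kwargs`. A minimal usage sketch, assuming this PR plus the matching stable-baselines3 `create_mlp` change:

from sb3_contrib import TQC

# Hypothetical user-facing call: dropout and layer norm are applied to the critic only,
# as the "For the critic only" comment in the signature above indicates.
model = TQC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
)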
27 changes: 16 additions & 11 deletions sb3_contrib/tqc/tqc.py
@@ -90,6 +90,7 @@ def __init__(
         replay_buffer_class: Optional[ReplayBuffer] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        policy_delay: int = 1,
         ent_coef: Union[str, float] = "auto",
         target_update_interval: int = 1,
         target_entropy: Union[str, float] = "auto",
@@ -142,6 +143,7 @@ def __init__(
         self.target_update_interval = target_update_interval
         self.ent_coef_optimizer = None
         self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
+        self.policy_delay = policy_delay

         if _init_setup_model:
             self._setup_model()
@@ -202,6 +204,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         actor_losses, critic_losses = [], []

         for gradient_step in range(gradient_steps):
+            self._n_updates += 1
+            update_actor = self._n_updates % self.policy_delay == 0
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

@@ -219,8 +223,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 # so we don't change it with other losses
                 # see https://github.com/rail-berkeley/softlearning/issues/60
                 ent_coef = th.exp(self.log_ent_coef.detach())
-                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
-                ent_coef_losses.append(ent_coef_loss.item())
+                if update_actor:
+                    ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
+                    ent_coef_losses.append(ent_coef_loss.item())
             else:
                 ent_coef = self.ent_coef_tensor

@@ -239,6 +244,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                 # Compute and cut quantiles at the next state
                 # batch x nets x quantiles
+                # Note: in DroQ, dropout seems to be enabled for the target network as well
                 next_quantiles = self.critic_target(replay_data.next_observations, next_actions)

                 # Sort and drop top k quantiles to control overestimation.
@@ -264,23 +270,22 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
             self.critic.optimizer.step()

             # Compute actor loss
-            qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
-            actor_loss = (ent_coef * log_prob - qf_pi).mean()
-            actor_losses.append(actor_loss.item())
+            if update_actor:
+                qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
+                actor_loss = (ent_coef * log_prob - qf_pi).mean()
+                actor_losses.append(actor_loss.item())

-            # Optimize the actor
-            self.actor.optimizer.zero_grad()
-            actor_loss.backward()
-            self.actor.optimizer.step()
+                # Optimize the actor
+                self.actor.optimizer.zero_grad()
+                actor_loss.backward()
+                self.actor.optimizer.step()

             # Update target networks
             if gradient_step % self.target_update_interval == 0:
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                 # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
                 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

-        self._n_updates += gradient_steps
-
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/ent_coef", np.mean(ent_coefs))
         self.logger.record("train/actor_loss", np.mean(actor_losses))
11 changes: 11 additions & 0 deletions tests/test_run.py
@@ -8,6 +8,17 @@
 from sb3_contrib.common.vec_env import AsyncEval


+def test_dropq():
+    model = TQC(
+        "MlpPolicy",
+        "Pendulum-v1",
+        policy_kwargs=dict(net_arch=[64, 64], layer_norm=True, dropout_rate=0.005),
+        verbose=1,
+        buffer_size=250,
+    )
+    model.learn(total_timesteps=300)
+
+
 @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
 def test_tqc(ent_coef):
     with pytest.warns(DeprecationWarning):  # `create_eval_env` and `eval_freq` are deprecated
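Reviewer note: beyond the smoke test above, a DroQ-like run would combine the new critic regularization with a high update-to-data ratio and a matching policy delay. The values below follow the DroQ paper's reported setup (UTD ratio 20, dropout 0.01, layer normalization) and are an untuned sketch assuming this PR, not settings taken from the diff:

from sb3_contrib import TQC

droq_like = TQC(
    "MlpPolicy",
    "Pendulum-v1",
    learning_starts=100,
    train_freq=1,        # one gradient phase per environment step
    gradient_steps=20,   # 20 critic updates per environment step (UTD = 20)
    policy_delay=20,     # actor/temperature updated once per 20 critic updates
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
    verbose=1,
)
droq_like.learn(total_timesteps=2_000)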