From d47c1a10dd8ca8cfce1e57d9b868c5139876498b Mon Sep 17 00:00:00 2001
From: nissymori
Date: Wed, 18 Dec 2024 23:04:27 +0900
Subject: [PATCH 1/2] fix xql for antmaze

---
 algos/xql.py | 32 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/algos/xql.py b/algos/xql.py
index c2ea10c..b5afb88 100644
--- a/algos/xql.py
+++ b/algos/xql.py
@@ -46,6 +46,7 @@ class XQLConfig(BaseModel):
     critic_lr: float = 3e-4
     tau: float = 0.005
     discount: float = 0.99
+    opt_decay_schedule: bool = True
     # XQL SPECIFIC
     expectile: float = (
         0.7  # FYI: for Hopper-me, 0.5 produce better result. (antmaze: tau=0.9)
     )
@@ -165,6 +166,7 @@ class Transition(NamedTuple):
     rewards: jnp.ndarray
     next_observations: jnp.ndarray
     dones: jnp.ndarray
+    dones_float: jnp.ndarray


 def get_normalization(dataset: Transition) -> float:
@@ -172,7 +174,7 @@ def get_normalization(dataset: Transition) -> float:
     dataset = jax.tree_util.tree_map(lambda x: np.array(x), dataset)
     returns = []
     ret = 0
-    for r, term in zip(dataset.rewards, dataset.dones):
+    for r, term in zip(dataset.rewards, dataset.dones_float):
         ret += r
         if term:
             returns.append(ret)
@@ -205,7 +207,8 @@ def get_dataset(
         actions=jnp.array(dataset["actions"], dtype=jnp.float32),
         rewards=jnp.array(dataset["rewards"], dtype=jnp.float32),
         next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32),
-        dones=jnp.array(dones_float, dtype=jnp.float32),
+        dones=jnp.array(dataset["terminals"], dtype=jnp.float32),
+        dones_float=jnp.array(dones_float, dtype=jnp.float32),
     )
     # normalize states
     obs_mean, obs_std = 0, 1
@@ -415,15 +418,15 @@ def value_loss_fn(
     def update_actor(
         self, train_state: XQLTrainState, batch: Transition, config: XQLConfig
     ) -> Tuple["XQLTrainState", Dict]:
-        def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
-            v = train_state.value.apply_fn(train_state.value.params, batch.observations)
-            q1, q2 = train_state.target_critic.apply_fn(
-                train_state.target_critic.params, batch.observations, batch.actions
-            )
-            q = jnp.minimum(q1, q2)
-            exp_a = jnp.exp((q - v) * config.beta)
-            exp_a = jnp.minimum(exp_a, 100.0)
+        v = train_state.value.apply_fn(train_state.value.params, batch.observations)
+        q1, q2 = train_state.target_critic.apply_fn(
+            train_state.target_critic.params, batch.observations, batch.actions
+        )
+        q = jnp.minimum(q1, q2)
+        exp_a = jnp.exp((q - v) * config.beta)
+        exp_a = jnp.minimum(exp_a, 100.0)
+        def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
             dist = train_state.actor.apply_fn(actor_params, batch.observations)
             log_probs = dist.log_prob(batch.actions)
             actor_loss = -(exp_a * log_probs).mean()
             return actor_loss
@@ -493,8 +496,13 @@ def create_xql_train_state(
         action_dim=action_dim,
         log_std_min=-5.0,
     )
-    schedule_fn = optax.cosine_decay_schedule(-config.actor_lr, config.max_steps)
-    actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn))
+
+    if config.opt_decay_schedule:
+        schedule_fn = optax.cosine_decay_schedule(-config.actor_lr, config.max_steps)
+        actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn))
+    else:
+        actor_tx = optax.adam(learning_rate=config.actor_lr)
+
     actor = TrainState.create(
         apply_fn=actor_model.apply,
         params=actor_model.init(actor_rng, observations),
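A quick sketch of the two optimizer branches above (illustrative only, not part of the patch; actor_lr and max_steps stand in for the XQLConfig fields). optax applies updates by adding them to the parameters, and scale_by_adam() emits the raw Adam direction, so the schedule must start at -actor_lr and decay toward zero for the chain to perform cosine-decayed gradient descent; optax.adam bakes the sign in itself, which is why the fallback branch needs no negation:

    import optax

    actor_lr, max_steps = 3e-4, 1_000_000  # illustrative values

    # opt_decay_schedule=True branch: Adam preconditioning followed by a
    # step size of -actor_lr * cosine_decay(step), applied additively.
    schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps)
    actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn))

    # opt_decay_schedule=False branch: plain Adam at a constant step size;
    # optax.adam already includes the descent sign internally.
    actor_tx_flat = optax.adam(learning_rate=actor_lr)

The dones/dones_float split serves the same antmaze fix: dones now carries the raw environment terminals used to mask bootstrapping, while dones_float marks episode boundaries (including timeouts), which is what get_normalization needs in order to segment returns.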
From c3c69c839e44343799863f3876878de821b9a58b Mon Sep 17 00:00:00 2001
From: nissymori
Date: Thu, 19 Dec 2024 00:04:03 +0900
Subject: [PATCH 2/2] init

---
 algos/xql.py | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/algos/xql.py b/algos/xql.py
index b5afb88..84837b4 100644
--- a/algos/xql.py
+++ b/algos/xql.py
@@ -54,6 +54,8 @@ class XQLConfig(BaseModel):
     beta: float = (
         3.0  # FYI: for Hopper-me, 6.0 produce better result. (antmaze: beta=10.0)
     )
+    dropout_rate: Optional[float] = None
+    value_dropout_rate: Optional[float] = None
     # XQL SPECIFIC
     vanilla: bool = False  # Of course, we do not use expectile loss
     sample_random_times: int = 0  # sample random times
@@ -87,26 +89,30 @@ class MLP(nn.Module):
     activate_final: bool = False
     kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init()
     layer_norm: bool = False
+    dropout_rate: Optional[float] = None

     @nn.compact
-    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
         for i, hidden_dims in enumerate(self.hidden_dims):
             x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x)
             if i + 1 < len(self.hidden_dims) or self.activate_final:
                 if self.layer_norm:  # Add layer norm after activation
                     x = nn.LayerNorm()(x)
                 x = self.activations(x)
+                if self.dropout_rate is not None and self.dropout_rate > 0:
+                    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not training)
         return x


 class Critic(nn.Module):
     hidden_dims: Sequence[int]
     activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
+    layer_norm: bool = False

     @nn.compact
     def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
         inputs = jnp.concatenate([observations, actions], -1)
-        critic = MLP((*self.hidden_dims, 1), activations=self.activations)(inputs)
+        critic = MLP((*self.hidden_dims, 1), activations=self.activations, layer_norm=self.layer_norm)(inputs)
         return jnp.squeeze(critic, -1)
@@ -126,10 +132,11 @@ def ensemblize(cls, num_qs, out_axes=0, **kwargs):
 class ValueCritic(nn.Module):
     hidden_dims: Sequence[int]
     layer_norm: bool = False
+    dropout_rate: Optional[float] = None

     @nn.compact
     def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
-        critic = MLP((*self.hidden_dims, 1), layer_norm=self.layer_norm)(observations)
+        critic = MLP((*self.hidden_dims, 1), layer_norm=self.layer_norm, dropout_rate=self.dropout_rate)(observations)
         return jnp.squeeze(critic, -1)
@@ -138,15 +145,16 @@ class GaussianPolicy(nn.Module):
     action_dim: int
     log_std_min: Optional[float] = -5.0
     log_std_max: Optional[float] = 2
-
+    dropout_rate: Optional[float] = None
     @nn.compact
     def __call__(
-        self, observations: jnp.ndarray, temperature: float = 1.0
+        self, observations: jnp.ndarray, temperature: float = 1.0, training: bool = False
     ) -> distrax.Distribution:
         outputs = MLP(
             self.hidden_dims,
             activate_final=True,
-        )(observations)
+            dropout_rate=self.dropout_rate,
+        )(observations, training)

         means = nn.Dense(
             self.action_dim, kernel_init=default_init()
         )
@@ -416,7 +424,7 @@ def value_loss_fn(

     @classmethod
     def update_actor(
-        self, train_state: XQLTrainState, batch: Transition, config: XQLConfig
+        self, train_state: XQLTrainState, batch: Transition, rng, config: XQLConfig
     ) -> Tuple["XQLTrainState", Dict]:
         v = train_state.value.apply_fn(train_state.value.params, batch.observations)
         q1, q2 = train_state.target_critic.apply_fn(
@@ -427,7 +435,7 @@ def update_actor(
         exp_a = jnp.minimum(exp_a, 100.0)
         def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
-            dist = train_state.actor.apply_fn(actor_params, batch.observations)
+            dist = train_state.actor.apply_fn(actor_params, batch.observations, training=True, rngs={"dropout": rng})
             log_probs = dist.log_prob(batch.actions)
             actor_loss = -(exp_a * log_probs).mean()
             return actor_loss
@@ -450,11 +458,11 @@ def update_n_times(
             )
             batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)

-            rng, subkey = jax.random.split(rng)
+            rng, value_rng, actor_rng = jax.random.split(rng, 3)
             train_state, value_loss = self.update_value(
-                train_state, batch, subkey, config
+                train_state, batch, value_rng, config
             )
-            train_state, actor_loss = self.update_actor(train_state, batch, config)
+            train_state, actor_loss = self.update_actor(train_state, batch, actor_rng, config)
             train_state, critic_loss = self.update_critic(train_state, batch, config)
             new_target_critic = target_update(
                 train_state.critic, train_state.target_critic, config.tau
@@ -495,6 +503,7 @@ def create_xql_train_state(
         config.hidden_dims,
         action_dim=action_dim,
         log_std_min=-5.0,
+        dropout_rate=config.dropout_rate,
     )

     if config.opt_decay_schedule:
@@ -521,7 +530,7 @@ def create_xql_train_state(
         tx=optax.adam(learning_rate=config.critic_lr),
     )
     # initialize value
-    value_model = ValueCritic(config.hidden_dims, layer_norm=config.layer_norm)
+    value_model = ValueCritic(config.hidden_dims, layer_norm=config.layer_norm, dropout_rate=config.value_dropout_rate)
     value = TrainState.create(
         apply_fn=value_model.apply,
         params=value_model.init(value_rng, observations),
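The dropout wiring in this second patch follows the standard Flax pattern: nn.Dropout draws randomness from a named "dropout" RNG stream, so every train-time apply needs a fresh key, which is why update_n_times now splits its key into value_rng and actor_rng. A minimal self-contained sketch (illustrative only; TinyMLP and the shapes are made up and are not the patched modules):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class TinyMLP(nn.Module):  # hypothetical stand-in for the patched MLP
        dropout_rate: float = 0.1

        @nn.compact
        def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
            x = nn.relu(nn.Dense(32)(x))
            # Identity when deterministic=True, i.e. at evaluation time.
            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not training)
            return nn.Dense(1)(x)

    rng = jax.random.PRNGKey(0)
    params_rng, dropout_rng = jax.random.split(rng)
    model = TinyMLP()
    x = jnp.ones((4, 8))
    params = model.init(params_rng, x)  # training=False here, so no dropout key needed
    # Train-time call: supply the "dropout" stream explicitly, just as the
    # patched update_actor does with rngs={"dropout": rng}.
    out = model.apply(params, x, training=True, rngs={"dropout": dropout_rng})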