algos/xql.py: 65 changes (41 additions, 24 deletions)
@@ -46,13 +46,16 @@ class XQLConfig(BaseModel):
critic_lr: float = 3e-4
tau: float = 0.005
discount: float = 0.99
opt_decay_schedule: bool = True
# XQL SPECIFIC
expectile: float = (
0.7 # FYI: for Hopper-me, 0.5 produces a better result. (antmaze: tau=0.9)
)
beta: float = (
3.0 # FYI: for Hopper-me, 6.0 produces a better result. (antmaze: beta=10.0)
)
dropout_rate: Optional[float] = None
value_dropout_rate: Optional[float] = None
# XQL SPECIFIC
vanilla: bool = False # Of course, we do not use expectile loss
sample_random_times: int = 0 # sample random times
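
For reference, a hypothetical override of the new config fields, using the values the inline comments suggest for antmaze (expectile 0.9, beta 10.0). The dropout values are purely illustrative, and the remaining XQLConfig fields are assumed to keep their defaults:

config = XQLConfig(
    expectile=0.9,             # per the inline comment: antmaze uses tau=0.9
    beta=10.0,                 # per the inline comment: antmaze uses beta=10.0
    opt_decay_schedule=False,  # use plain Adam for the actor (see create_xql_train_state)
    dropout_rate=0.1,          # policy-MLP dropout; illustrative value, not from the PR
    value_dropout_rate=0.5,    # value-network dropout; illustrative value, not from the PR
)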
@@ -86,26 +89,30 @@ class MLP(nn.Module):
activate_final: bool = False
kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init()
layer_norm: bool = False
dropout_rate: Optional[float] = None

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
for i, hidden_dims in enumerate(self.hidden_dims):
x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x)
if i + 1 < len(self.hidden_dims) or self.activate_final:
if self.layer_norm: # Add layer norm after activation
x = nn.LayerNorm()(x)
x = self.activations(x)
if self.dropout_rate is not None and self.dropout_rate > 0:
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not training)
return x


class Critic(nn.Module):
hidden_dims: Sequence[int]
activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
layer_norm: bool = False

@nn.compact
def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
inputs = jnp.concatenate([observations, actions], -1)
critic = MLP((*self.hidden_dims, 1), activations=self.activations)(inputs)
critic = MLP((*self.hidden_dims, 1), activations=self.activations, layer_norm=self.layer_norm)(inputs)
return jnp.squeeze(critic, -1)


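Because MLP now contains an nn.Dropout layer, any apply() call that passes training=True must also supply an RNG stream named "dropout". A minimal usage sketch (names, shapes, and the rate are illustrative, not from the PR):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
params_rng, dropout_rng = jax.random.split(rng)
mlp = MLP((256, 256), dropout_rate=0.1)
params = mlp.init(params_rng, jnp.ones((1, 17)))   # training defaults to False: dropout inactive
out = mlp.apply(params, jnp.ones((1, 17)), training=True,
                rngs={"dropout": dropout_rng})      # dropout active, so the stream is required
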
@@ -125,10 +132,11 @@ def ensemblize(cls, num_qs, out_axes=0, **kwargs):
class ValueCritic(nn.Module):
hidden_dims: Sequence[int]
layer_norm: bool = False
dropout_rate: Optional[float] = None

@nn.compact
def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
critic = MLP((*self.hidden_dims, 1), layer_norm=self.layer_norm)(observations)
critic = MLP((*self.hidden_dims, 1), layer_norm=self.layer_norm, dropout_rate=self.dropout_rate)(observations)
return jnp.squeeze(critic, -1)


@@ -137,15 +145,16 @@ class GaussianPolicy(nn.Module):
action_dim: int
log_std_min: Optional[float] = -5.0
log_std_max: Optional[float] = 2

dropout_rate: Optional[float] = None
@nn.compact
def __call__(
self, observations: jnp.ndarray, temperature: float = 1.0
self, observations: jnp.ndarray, temperature: float = 1.0, training: bool = False
) -> distrax.Distribution:
outputs = MLP(
self.hidden_dims,
activate_final=True,
)(observations)
dropout_rate=self.dropout_rate,
)(observations, training)

means = nn.Dense(
self.action_dim, kernel_init=default_init()
@@ -165,14 +174,15 @@ class Transition(NamedTuple):
rewards: jnp.ndarray
next_observations: jnp.ndarray
dones: jnp.ndarray
dones_float: jnp.ndarray


def get_normalization(dataset: Transition) -> float:
# into numpy.ndarray
dataset = jax.tree_util.tree_map(lambda x: np.array(x), dataset)
returns = []
ret = 0
for r, term in zip(dataset.rewards, dataset.dones):
for r, term in zip(dataset.rewards, dataset.dones_float):
ret += r
if term:
returns.append(ret)
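
With the new field, dones carries the raw terminals (presumably used to mask the Bellman bootstrap), while dones_float marks every trajectory boundary, which is what get_normalization needs to segment episodes. A hedged sketch of the usual IQL-style construction of dones_float; the actual code lives in get_dataset and may differ:

import numpy as np

dones_float = np.zeros_like(dataset["rewards"])
for i in range(len(dones_float) - 1):
    # boundary where the next transition does not continue the trajectory,
    # or where the environment actually terminated
    if np.linalg.norm(dataset["observations"][i + 1] - dataset["next_observations"][i]) > 1e-6 \
            or dataset["terminals"][i] == 1.0:
        dones_float[i] = 1.0
dones_float[-1] = 1.0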
@@ -205,7 +215,8 @@ def get_dataset(
actions=jnp.array(dataset["actions"], dtype=jnp.float32),
rewards=jnp.array(dataset["rewards"], dtype=jnp.float32),
next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32),
dones=jnp.array(dones_float, dtype=jnp.float32),
dones=jnp.array(dataset["terminals"], dtype=jnp.float32),
dones_float=jnp.array(dones_float, dtype=jnp.float32),
)
# normalize states
obs_mean, obs_std = 0, 1
@@ -413,18 +424,18 @@ def value_loss_fn(

@classmethod
def update_actor(
self, train_state: XQLTrainState, batch: Transition, config: XQLConfig
self, train_state: XQLTrainState, batch: Transition, rng, config: XQLConfig
) -> Tuple["XQLTrainState", Dict]:
def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
v = train_state.value.apply_fn(train_state.value.params, batch.observations)
q1, q2 = train_state.target_critic.apply_fn(
train_state.target_critic.params, batch.observations, batch.actions
)
q = jnp.minimum(q1, q2)
exp_a = jnp.exp((q - v) * config.beta)
exp_a = jnp.minimum(exp_a, 100.0)
v = train_state.value.apply_fn(train_state.value.params, batch.observations)
q1, q2 = train_state.target_critic.apply_fn(
train_state.target_critic.params, batch.observations, batch.actions
)
q = jnp.minimum(q1, q2)
exp_a = jnp.exp((q - v) * config.beta)
exp_a = jnp.minimum(exp_a, 100.0)

dist = train_state.actor.apply_fn(actor_params, batch.observations)
def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
dist = train_state.actor.apply_fn(actor_params, batch.observations, training=True, rngs={"dropout": rng})
log_probs = dist.log_prob(batch.actions)
actor_loss = -(exp_a * log_probs).mean()
return actor_loss
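
For readability, the objective computed above, restated in comments (this mirrors the code; the 100.0 clip is from the diff):

# A(s, a) = min(Q1, Q2)(s, a) - V(s)            advantage from the target critics and value net
# w(s, a) = min(exp(beta * A(s, a)), 100.0)     clipped exponential (AWR-style) weight
# L_actor = -E[ w(s, a) * log pi(a | s) ]       advantage-weighted log-likelihood loss
# Hoisting w out of actor_loss_fn is safe: it does not depend on actor_params, so it is a
# constant under jax.grad; the dropout rng is only needed for the actor forward pass.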
@@ -447,11 +458,11 @@ def update_n_times(
)
batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)

rng, subkey = jax.random.split(rng)
rng, value_rng, actor_rng = jax.random.split(rng, 3)
train_state, value_loss = self.update_value(
train_state, batch, subkey, config
train_state, batch, value_rng, config
)
train_state, actor_loss = self.update_actor(train_state, batch, config)
train_state, actor_loss = self.update_actor(train_state, batch, actor_rng, config)
train_state, critic_loss = self.update_critic(train_state, batch, config)
new_target_critic = target_update(
train_state.critic, train_state.target_critic, config.tau
@@ -492,9 +503,15 @@ def create_xql_train_state(
config.hidden_dims,
action_dim=action_dim,
log_std_min=-5.0,
dropout_rate=config.dropout_rate,
)
schedule_fn = optax.cosine_decay_schedule(-config.actor_lr, config.max_steps)
actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn))

if config.opt_decay_schedule:
schedule_fn = optax.cosine_decay_schedule(-config.actor_lr, config.max_steps)
actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn))
else:
actor_tx = optax.adam(learning_rate=config.actor_lr)

actor = TrainState.create(
apply_fn=actor_model.apply,
params=actor_model.init(actor_rng, observations),
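
A note on the schedule branch above: this restates existing optax behaviour rather than anything in the PR. A sketch with explicit names (the max_steps value is illustrative):

import optax

# scale_by_adam() emits a raw Adam step; scale_by_schedule() multiplies it by the
# (negative, cosine-decayed) schedule value, so the chain yields a descent update
# whose step size decays from actor_lr towards 0 -- the same role the minus sign
# plays inside optax.adam(learning_rate=actor_lr).
actor_lr, max_steps = 3e-4, 1_000_000
schedule_fn = optax.cosine_decay_schedule(init_value=-actor_lr, decay_steps=max_steps)
actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn))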
@@ -513,7 +530,7 @@ def create_xql_train_state(
tx=optax.adam(learning_rate=config.critic_lr),
)
# initialize value
value_model = ValueCritic(config.hidden_dims, layer_norm=config.layer_norm)
value_model = ValueCritic(config.hidden_dims, layer_norm=config.layer_norm, dropout_rate=config.value_dropout_rate)
value = TrainState.create(
apply_fn=value_model.apply,
params=value_model.init(value_rng, observations),