122 changes: 122 additions & 0 deletions examples/mujoco_locomotion_ppo.py
@@ -0,0 +1,122 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from mushroom_rl.algorithms.actor_critic import PPO
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments import Ant, HalfCheetah, Hopper, Walker2D
from mushroom_rl.policy import GaussianTorchPolicy

from tqdm import trange


class Network(nn.Module):
def __init__(self, input_shape, output_shape, n_features, **kwargs):
super(Network, self).__init__()

n_input = input_shape[-1]
n_output = output_shape[0]

self._h1 = nn.Linear(n_input, n_features)
self._h2 = nn.Linear(n_features, n_features)
self._h3 = nn.Linear(n_features, n_output)

nn.init.xavier_uniform_(
self._h1.weight, gain=nn.init.calculate_gain("relu") / 10
)
nn.init.xavier_uniform_(
self._h2.weight, gain=nn.init.calculate_gain("relu") / 10
)
nn.init.xavier_uniform_(
self._h3.weight, gain=nn.init.calculate_gain("linear") / 10
)

def forward(self, state, **kwargs):
features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
features2 = F.relu(self._h2(features1))
a = self._h3(features2)

return a


def experiment(env, n_epochs, n_steps, n_episodes_test):
np.random.seed()

logger = Logger(PPO.__name__, results_dir=None)
logger.strong_line()
logger.info("Experiment Algorithm: " + PPO.__name__)

mdp = env()

actor_lr = 3e-4
critic_lr = 3e-4
n_features = 32
batch_size = 64
n_epochs_policy = 10
eps = 0.2
lam = 0.95
std_0 = 1.0
n_steps_per_fit = 2000

critic_params = dict(
network=Network,
optimizer={"class": optim.Adam, "params": {"lr": critic_lr}},
loss=F.mse_loss,
n_features=n_features,
batch_size=batch_size,
input_shape=mdp.info.observation_space.shape,
output_shape=(1,),
)

alg_params = dict(
actor_optimizer={"class": optim.Adam, "params": {"lr": actor_lr}},
n_epochs_policy=n_epochs_policy,
batch_size=batch_size,
eps_ppo=eps,
lam=lam,
critic_params=critic_params,
)

policy_params = dict(std_0=std_0, n_features=n_features)

policy = GaussianTorchPolicy(
Network,
mdp.info.observation_space.shape,
mdp.info.action_space.shape,
**policy_params,
)

agent = PPO(mdp.info, policy, **alg_params)

core = Core(agent, mdp)

dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
E = agent.policy.entropy()

logger.epoch_info(0, J=J, R=R, entropy=E)

for it in trange(n_epochs, leave=False):
core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
E = agent.policy.entropy()

logger.epoch_info(it + 1, J=J, R=R, entropy=E)

logger.info("Press a button to visualize")
input()
core.evaluate(n_episodes=5, render=True)


if __name__ == "__main__":
envs = [Ant, HalfCheetah, Hopper, Walker2D]
for env in envs:
experiment(env=env, n_epochs=50, n_steps=30000, n_episodes_test=10)
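For context on the `eps_ppo=0.2` setting in the example above: PPO clips the policy probability ratio inside its surrogate objective, and `eps` is the clipping radius. A minimal sketch of that loss, not MushroomRL's internal implementation:

```python
import torch

def clipped_surrogate_loss(new_log_prob, old_log_prob, advantage, eps=0.2):
    # Probability ratio between the updated policy and the data-collecting policy.
    ratio = torch.exp(new_log_prob - old_log_prob)
    # Clipping removes the incentive to push the ratio outside [1 - eps, 1 + eps].
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
    # Pessimistic (elementwise minimum) objective, negated to form a loss.
    return -torch.mean(torch.min(ratio * advantage, clipped * advantage))
```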
4 changes: 2 additions & 2 deletions mushroom_rl/environments/mujoco.py
@@ -153,7 +153,7 @@ def step(self, action):

absorbing = self.is_absorbing(cur_obs)
reward = self.reward(self._obs, action, cur_obs, absorbing)
-        info = self._create_info_dictionary(cur_obs)
+        info = self._create_info_dictionary(cur_obs, action)

self._obs = cur_obs

@@ -199,7 +199,7 @@ def _create_observation(self, obs):
"""
return obs

-    def _create_info_dictionary(self, obs):
+    def _create_info_dictionary(self, obs, action):
"""
This method can be overridden to create a custom info dictionary.

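The new `action` argument lets subclasses report action-dependent diagnostics (the `Ant` environment below uses it for its control cost). A schematic override, with `MyEnv` and the `torque_cost` key as illustrative names, not part of this PR:

```python
import numpy as np
from mushroom_rl.environments.mujoco import MuJoCo

class MyEnv(MuJoCo):
    def _create_info_dictionary(self, obs, action):
        # Expose a per-step torque cost computed from the applied action.
        return {"torque_cost": float(np.sum(np.square(action)))}
```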
8 changes: 8 additions & 0 deletions mushroom_rl/environments/mujoco_envs/__init__.py
@@ -1,8 +1,16 @@
from .ball_in_a_cup import BallInACup
from .air_hockey import AirHockeyHit, AirHockeyDefend, AirHockeyPrepare, AirHockeyRepel
from .ant import Ant
from .half_cheetah import HalfCheetah
from .hopper import Hopper
from .walker_2d import Walker2D

BallInACup.register()
AirHockeyHit.register()
AirHockeyDefend.register()
AirHockeyPrepare.register()
AirHockeyRepel.register()
Ant.register()
HalfCheetah.register()
Hopper.register()
Walker2D.register()
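Registering the classes at import time makes the new environments constructible by name through MushroomRL's environment registry. A minimal usage sketch; the keyword arguments shown are assumptions forwarded to the `Ant` constructor:

```python
from mushroom_rl.core import Environment

# Build a registered environment by its class name; extra kwargs are
# passed through to the constructor.
mdp = Environment.make("Ant", gamma=0.99, horizon=1000)
```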
2 changes: 1 addition & 1 deletion mushroom_rl/environments/mujoco_envs/air_hockey/double.py
@@ -115,7 +115,7 @@ def _create_observation(self, state):
obs = super(AirHockeyDouble, self)._create_observation(state)
return np.append(obs, [self.robot_1_hit, self.robot_2_hit, self.has_bounce])

-    def _create_info_dictionary(self, obs):
+    def _create_info_dictionary(self, obs, action):
constraints = {"agent-1": {}, "agent-2":{}}

for i, key in enumerate(constraints.keys()):
3 changes: 1 addition & 2 deletions mushroom_rl/environments/mujoco_envs/air_hockey/single.py
@@ -103,7 +103,7 @@ def _create_observation(self, state):
obs = super(AirHockeySingle, self)._create_observation(state)
return np.append(obs, [self.has_hit, self.has_bounce])

-    def _create_info_dictionary(self, obs):
+    def _create_info_dictionary(self, obs, action):
constraints = {}
q_pos = self.obs_helper.get_joint_pos_from_obs(obs)
q_vel = self.obs_helper.get_joint_vel_from_obs(obs)
@@ -134,4 +134,3 @@ def _create_info_dictionary(self, obs):
constraints["joint_vel_constraints"][3:] = self.obs_helper.get_joint_vel_limits()[0] - q_vel

return constraints

203 changes: 203 additions & 0 deletions mushroom_rl/environments/mujoco_envs/ant.py
@@ -0,0 +1,203 @@
from pathlib import Path
from typing import Tuple

import numpy as np
from mushroom_rl.environments.mujoco import MuJoCo, ObservationType
from mushroom_rl.rl_utils.spaces import Box
import mujoco


class Ant(MuJoCo):
"""
The Ant MuJoCo environment, as presented in:
"High-Dimensional Continuous Control Using Generalized Advantage Estimation",
John Schulman et al., 2015, and as implemented in Gymnasium.
"""

def __init__(
self,
gamma: float = 0.99,
horizon: int = 1000,
forward_reward_weight: float = 1.0,
ctrl_cost_weight: float = 0.5,
contact_cost_weight: float = 5e-4,
healthy_reward: float = 1.0,
terminate_when_unhealthy: bool = True,
healthy_z_range: Tuple[float, float] = (0.2, 1.0),
contact_force_range: Tuple[float, float] = (-1.0, 1.0),
reset_noise_scale: float = 0.1,
n_substeps: int = 5,
exclude_current_positions_from_observation: bool = True,
use_contact_forces: bool = False,
**viewer_params,
):
"""
Constructor.

"""
xml_path = (
Path(__file__).resolve().parent / "data" / "ant" / "model.xml"
).as_posix()

# The actuator order matches the Gymnasium Ant implementation
actuation_spec = [
"hip_4",
"ankle_4",
"hip_1",
"ankle_1",
"hip_2",
"ankle_2",
"hip_3",
"ankle_3",
]

observation_spec = [
("root_pose", "root", ObservationType.JOINT_POS),
("hip_1_pos", "hip_1", ObservationType.JOINT_POS),
("ankle_1_pos", "ankle_1", ObservationType.JOINT_POS),
("hip_2_pos", "hip_2", ObservationType.JOINT_POS),
("ankle_2_pos", "ankle_2", ObservationType.JOINT_POS),
("hip_3_pos", "hip_3", ObservationType.JOINT_POS),
("ankle_3_pos", "ankle_3", ObservationType.JOINT_POS),
("hip_4_pos", "hip_4", ObservationType.JOINT_POS),
("ankle_4_pos", "ankle_4", ObservationType.JOINT_POS),
("root_vel", "root", ObservationType.JOINT_VEL),
("hip_1_vel", "hip_1", ObservationType.JOINT_VEL),
("ankle_1_vel", "ankle_1", ObservationType.JOINT_VEL),
("hip_2_vel", "hip_2", ObservationType.JOINT_VEL),
("ankle_2_vel", "ankle_2", ObservationType.JOINT_VEL),
("hip_3_vel", "hip_3", ObservationType.JOINT_VEL),
("ankle_3_vel", "ankle_3", ObservationType.JOINT_VEL),
("hip_4_vel", "hip_4", ObservationType.JOINT_VEL),
("ankle_4_vel", "ankle_4", ObservationType.JOINT_VEL),
]

additional_data_spec = [
("torso_pos", "torso", ObservationType.BODY_POS),
("torso_vel", "torso", ObservationType.BODY_VEL_WORLD),
]

collision_groups = [
("torso", ["torso_geom"]),
("floor", ["floor"]),
]

self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight
self._contact_cost_weight = contact_cost_weight
self._healthy_reward = healthy_reward
self._terminate_when_unhealthy = terminate_when_unhealthy
self._healthy_z_range = healthy_z_range
self._contact_force_range = contact_force_range
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
self._use_contact_forces = use_contact_forces

super().__init__(
xml_file=xml_path,
gamma=gamma,
horizon=horizon,
observation_spec=observation_spec,
actuation_spec=actuation_spec,
collision_groups=collision_groups,
additional_data_spec=additional_data_spec,
n_substeps=n_substeps,
**viewer_params,
)

def _modify_mdp_info(self, mdp_info):
if self._exclude_current_positions_from_observation:
self.obs_helper.remove_obs("root_pose", 0)
self.obs_helper.remove_obs("root_pose", 1)
if self._use_contact_forces:
self.obs_helper.add_obs("collision_force", 6)
mdp_info = super()._modify_mdp_info(mdp_info)
mdp_info.observation_space = Box(*self.obs_helper.get_obs_limits())
return mdp_info

def _create_observation(self, obs):
obs = super()._create_observation(obs)
if self._use_contact_forces:
collision_force = self._get_collision_force("torso", "floor")
obs = np.concatenate([obs, collision_force])
return obs

def _is_finite(self):
states = self.get_states()
return np.isfinite(states).all()

def _is_within_z_range(self):
z_pos = self._read_data("torso_pos")[2]
min_z, max_z = self._healthy_z_range
return min_z <= z_pos <= max_z

def _is_healthy(self):
is_healthy = self._is_finite() and self._is_within_z_range()
return is_healthy

def is_absorbing(self, obs):
absorbing = self._terminate_when_unhealthy and not self._is_healthy()
return absorbing

def _get_healthy_reward(self, obs):
return (
self._terminate_when_unhealthy and self._is_healthy()
) * self._healthy_reward

def _get_forward_reward(self):
forward_reward = self._read_data("torso_vel")[3]
return self._forward_reward_weight * forward_reward

def _get_ctrl_cost(self, action):
ctrl_cost = np.sum(np.square(action))
return self._ctrl_cost_weight * ctrl_cost

def _get_contact_cost(self, obs):
collision_force = self.obs_helper.get_from_obs(obs, "collision_force")
contact_cost = np.sum(
np.square(np.clip(collision_force, *self._contact_force_range))
)
return self._contact_cost_weight * contact_cost

def reward(self, obs, action, next_obs, absorbing):
healthy_reward = self._get_healthy_reward(next_obs)
forward_reward = self._get_forward_reward()
cost = self._get_ctrl_cost(action)
if self._use_contact_forces:
contact_cost = self._get_contact_cost(next_obs)
cost += contact_cost
reward = healthy_reward + forward_reward - cost
return reward

def _generate_noise(self):
self._data.qpos[:] = self._data.qpos + np.random.uniform(
-self._reset_noise_scale, self._reset_noise_scale, size=self._model.nq
)

self._data.qvel[:] = (
self._data.qvel
+ self._reset_noise_scale * np.random.standard_normal(self._model.nv)
)

def setup(self, obs):
super().setup(obs)

self._generate_noise()

mujoco.mj_forward(self._model, self._data) # type: ignore

def _create_info_dictionary(self, obs, action):
info = {
"healthy_reward": self._get_healthy_reward(obs),
"forward_reward": self._get_forward_reward(),
}
info["ctrl_cost"] = self._get_ctrl_cost(action)
if self._use_contact_forces:
info["contact_cost"] = self._get_contact_cost(obs)
return info

def get_states(self):
"""Return the position and velocity joint states of the model"""
return np.concatenate([self._data.qpos.flat, self._data.qvel.flat])
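A minimal sketch of exercising the new environment directly; the random-action rollout is illustrative and not part of this PR, and the step return signature follows the MuJoCo base class in this repository:

```python
import numpy as np
from mushroom_rl.environments import Ant

env = Ant(horizon=100)
env.reset()
for _ in range(10):
    # Sample a random action within the actuator bounds and step once.
    action = np.random.uniform(env.info.action_space.low,
                               env.info.action_space.high)
    obs, reward, absorbing, info = env.step(action)
    if absorbing:
        env.reset()
```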