-
Notifications
You must be signed in to change notification settings - Fork 155
WIP: Add MuJoCo Locomotion Environments #156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
66b04b3
Add Ant environment
noahfarr 16bd46d
Add HalfCheetah environment
noahfarr 45615f8
Add HalfCheetah example
noahfarr 2791921
Add Ant xml model
noahfarr 1d50a9a
Add Hopper environment
noahfarr 5ffafa9
Add Walker2D environment
noahfarr cdd3a81
Add __init__.py to data
noahfarr dea712f
Finish locomotion envs
noahfarr 25e2ae3
Merge branch 'MushroomRL:dev' into locomotion
noahfarr 4a57acc
Add new mujoco envs to __init__
noahfarr 00180f1
Revert formatting changes
noahfarr 05ee37c
Move mujoco locomotion examples to one file
noahfarr 1419b3f
Add tests for locomotion environments
noahfarr d92f20d
Fix minor details in example
noahfarr f051704
Update test_locomotion.py
noahfarr b9c7c6c
Comment out render for testing
noahfarr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| from argparse import ArgumentParser | ||
| import numpy as np | ||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| import torch.optim as optim | ||
|
|
||
| from mushroom_rl.algorithms.actor_critic import PPO | ||
| from mushroom_rl.core import Core, Logger | ||
| from mushroom_rl.environments import Ant, HalfCheetah, Hopper, Walker2D | ||
| from mushroom_rl.policy import GaussianTorchPolicy | ||
|
|
||
| from tqdm import trange | ||
|
|
||
|
|
||
| class Network(nn.Module): | ||
| def __init__(self, input_shape, output_shape, n_features, **kwargs): | ||
| super(Network, self).__init__() | ||
|
|
||
| n_input = input_shape[-1] | ||
| n_output = output_shape[0] | ||
|
|
||
| self._h1 = nn.Linear(n_input, n_features) | ||
| self._h2 = nn.Linear(n_features, n_features) | ||
| self._h3 = nn.Linear(n_features, n_output) | ||
|
|
||
| nn.init.xavier_uniform_( | ||
| self._h1.weight, gain=nn.init.calculate_gain("relu") / 10 | ||
| ) | ||
| nn.init.xavier_uniform_( | ||
| self._h2.weight, gain=nn.init.calculate_gain("relu") / 10 | ||
| ) | ||
| nn.init.xavier_uniform_( | ||
| self._h3.weight, gain=nn.init.calculate_gain("linear") / 10 | ||
| ) | ||
|
|
||
| def forward(self, state, **kwargs): | ||
| features1 = F.relu(self._h1(torch.squeeze(state, 1).float())) | ||
| features2 = F.relu(self._h2(features1)) | ||
| a = self._h3(features2) | ||
|
|
||
| return a | ||
|
|
||
|
|
||
| def experiment(env, n_epochs, n_steps, n_episodes_test): | ||
| np.random.seed() | ||
|
|
||
| logger = Logger(PPO.__name__, results_dir=None) | ||
| logger.strong_line() | ||
| logger.info("Experiment Algorithm: " + PPO.__name__) | ||
|
|
||
| mdp = env() | ||
|
|
||
| actor_lr = 3e-4 | ||
| critic_lr = 3e-4 | ||
| n_features = 32 | ||
| batch_size = 64 | ||
| n_epochs_policy = 10 | ||
| eps = 0.2 | ||
| lam = 0.95 | ||
| std_0 = 1.0 | ||
| n_steps_per_fit = 2000 | ||
|
|
||
| critic_params = dict( | ||
| network=Network, | ||
| optimizer={"class": optim.Adam, "params": {"lr": critic_lr}}, | ||
| loss=F.mse_loss, | ||
| n_features=n_features, | ||
| batch_size=batch_size, | ||
| input_shape=mdp.info.observation_space.shape, | ||
| output_shape=(1,), | ||
| ) | ||
|
|
||
| alg_params = dict( | ||
| actor_optimizer={"class": optim.Adam, "params": {"lr": actor_lr}}, | ||
| n_epochs_policy=n_epochs_policy, | ||
| batch_size=batch_size, | ||
| eps_ppo=eps, | ||
| lam=lam, | ||
| critic_params=critic_params, | ||
| ) | ||
|
|
||
| policy_params = dict(std_0=std_0, n_features=n_features) | ||
|
|
||
| policy = GaussianTorchPolicy( | ||
| Network, | ||
| mdp.info.observation_space.shape, | ||
| mdp.info.action_space.shape, | ||
| **policy_params, | ||
| ) | ||
|
|
||
| agent = PPO(mdp.info, policy, **alg_params) | ||
|
|
||
| core = Core(agent, mdp) | ||
|
|
||
| dataset = core.evaluate(n_episodes=n_episodes_test, render=False) | ||
|
|
||
| J = np.mean(dataset.discounted_return) | ||
| R = np.mean(dataset.undiscounted_return) | ||
| E = agent.policy.entropy() | ||
|
|
||
| logger.epoch_info(0, J=J, R=R, entropy=E) | ||
|
|
||
| for it in trange(n_epochs, leave=False): | ||
| core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit) | ||
| dataset = core.evaluate(n_episodes=n_episodes_test, render=False) | ||
|
|
||
| J = np.mean(dataset.discounted_return) | ||
| R = np.mean(dataset.undiscounted_return) | ||
| E = agent.policy.entropy() | ||
|
|
||
| logger.epoch_info(it + 1, J=J, R=R, entropy=E) | ||
|
|
||
| logger.info("Press a button to visualize") | ||
| input() | ||
| core.evaluate(n_episodes=5, render=True) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| envs = [Ant, HalfCheetah, Hopper, Walker2D] | ||
| for env in envs: | ||
| experiment(env=env, n_epochs=50, n_steps=30000, n_episodes_test=10) |
noahfarr marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,16 @@ | ||
| from .ball_in_a_cup import BallInACup | ||
| from .air_hockey import AirHockeyHit, AirHockeyDefend, AirHockeyPrepare, AirHockeyRepel | ||
| from .ant import Ant | ||
| from .half_cheetah import HalfCheetah | ||
| from .hopper import Hopper | ||
| from .walker_2d import Walker2D | ||
|
|
||
| BallInACup.register() | ||
| AirHockeyHit.register() | ||
| AirHockeyDefend.register() | ||
| AirHockeyPrepare.register() | ||
| AirHockeyRepel.register() | ||
| Ant.register() | ||
| HalfCheetah.register() | ||
| Hopper.register() | ||
| Walker2D.register() |
noahfarr marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
noahfarr marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,203 @@ | ||
| from pathlib import Path | ||
| from typing import Tuple | ||
|
|
||
| import numpy as np | ||
| from mushroom_rl.environments.mujoco import MuJoCo, ObservationType | ||
| from mushroom_rl.rl_utils.spaces import Box | ||
| import mujoco | ||
|
|
||
|
|
||
| class Ant(MuJoCo): | ||
| """ | ||
| The Ant MuJoCo environment as presented in: | ||
| "High-Dimensional Continuous Control Using Generalized Advantage Estimation". John Schulman et. al.. 2015. | ||
| and implemented in Gymnasium | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| gamma: float = 0.99, | ||
| horizon: int = 1000, | ||
| forward_reward_weight: float = 1.0, | ||
| ctrl_cost_weight: float = 0.5, | ||
| contact_cost_weight: float = 5e-4, | ||
| healthy_reward: float = 1.0, | ||
| terminate_when_unhealthy: bool = True, | ||
| healthy_z_range: Tuple[float, float] = (0.2, 1.0), | ||
| contact_force_range: Tuple[float, float] = (-1.0, 1.0), | ||
| reset_noise_scale: float = 0.1, | ||
| n_substeps: int = 5, | ||
| exclude_current_positions_from_observation: bool = True, | ||
| use_contact_forces: bool = False, | ||
| **viewer_params, | ||
| ): | ||
| """ | ||
| Constructor. | ||
|
|
||
| """ | ||
| xml_path = ( | ||
| Path(__file__).resolve().parent / "data" / "ant" / "model.xml" | ||
| ).as_posix() | ||
|
|
||
| # This order is correct as specified in gymnasium | ||
| actuation_spec = [ | ||
| "hip_4", | ||
| "ankle_4", | ||
| "hip_1", | ||
| "ankle_1", | ||
| "hip_2", | ||
| "ankle_2", | ||
| "hip_3", | ||
| "ankle_3", | ||
| ] | ||
|
|
||
| observation_spec = [ | ||
| ("root_pose", "root", ObservationType.JOINT_POS), | ||
| ("hip_1_pos", "hip_1", ObservationType.JOINT_POS), | ||
| ("ankle_1_pos", "ankle_1", ObservationType.JOINT_POS), | ||
| ("hip_2_pos", "hip_2", ObservationType.JOINT_POS), | ||
| ("ankle_2_pos", "ankle_2", ObservationType.JOINT_POS), | ||
| ("hip_3_pos", "hip_3", ObservationType.JOINT_POS), | ||
| ("ankle_3_pos", "ankle_3", ObservationType.JOINT_POS), | ||
| ("hip_4_pos", "hip_4", ObservationType.JOINT_POS), | ||
| ("ankle_4_pos", "ankle_4", ObservationType.JOINT_POS), | ||
| ("root_vel", "root", ObservationType.JOINT_VEL), | ||
| ("hip_1_vel", "hip_1", ObservationType.JOINT_VEL), | ||
| ("ankle_1_vel", "ankle_1", ObservationType.JOINT_VEL), | ||
| ("hip_2_vel", "hip_2", ObservationType.JOINT_VEL), | ||
| ("ankle_2_vel", "ankle_2", ObservationType.JOINT_VEL), | ||
| ("hip_3_vel", "hip_3", ObservationType.JOINT_VEL), | ||
| ("ankle_3_vel", "ankle_3", ObservationType.JOINT_VEL), | ||
| ("hip_4_vel", "hip_4", ObservationType.JOINT_VEL), | ||
| ("ankle_4_vel", "ankle_4", ObservationType.JOINT_VEL), | ||
| ] | ||
|
|
||
| additional_data_spec = [ | ||
| ("torso_pos", "torso", ObservationType.BODY_POS), | ||
| ("torso_vel", "torso", ObservationType.BODY_VEL_WORLD), | ||
| ] | ||
|
|
||
| collision_groups = [ | ||
| ("torso", ["torso_geom"]), | ||
| ("floor", ["floor"]), | ||
| ] | ||
|
|
||
| self._forward_reward_weight = forward_reward_weight | ||
| self._ctrl_cost_weight = ctrl_cost_weight | ||
| self._contact_cost_weight = contact_cost_weight | ||
| self._healthy_reward = healthy_reward | ||
| self._terminate_when_unhealthy = terminate_when_unhealthy | ||
| self._healthy_z_range = healthy_z_range | ||
| self._contact_force_range = contact_force_range | ||
| self._reset_noise_scale = reset_noise_scale | ||
| self._exclude_current_positions_from_observation = ( | ||
| exclude_current_positions_from_observation | ||
| ) | ||
| self._use_contact_forces = use_contact_forces | ||
|
|
||
| super().__init__( | ||
| xml_file=xml_path, | ||
| gamma=gamma, | ||
| horizon=horizon, | ||
| observation_spec=observation_spec, | ||
| actuation_spec=actuation_spec, | ||
| collision_groups=collision_groups, | ||
| additional_data_spec=additional_data_spec, | ||
| n_substeps=n_substeps, | ||
| **viewer_params, | ||
| ) | ||
|
|
||
| def _modify_mdp_info(self, mdp_info): | ||
| if self._exclude_current_positions_from_observation: | ||
| self.obs_helper.remove_obs("root_pose", 0) | ||
| self.obs_helper.remove_obs("root_pose", 1) | ||
| if self._use_contact_forces: | ||
| self.obs_helper.add_obs("collision_force", 6) | ||
| mdp_info = super()._modify_mdp_info(mdp_info) | ||
| mdp_info.observation_space = Box(*self.obs_helper.get_obs_limits()) | ||
| return mdp_info | ||
|
|
||
| def _create_observation(self, obs): | ||
| obs = super()._create_observation(obs) | ||
| if self._use_contact_forces: | ||
| collision_force = self._get_collision_force("torso", "floor") | ||
| obs = np.concatenate([obs, collision_force]) | ||
| return obs | ||
|
|
||
| def _is_finite(self): | ||
| states = self.get_states() | ||
| return np.isfinite(states).all() | ||
|
|
||
| def _is_within_z_range(self): | ||
| z_pos = self._read_data("torso_pos")[2] | ||
| min_z, max_z = self._healthy_z_range | ||
| return min_z <= z_pos <= max_z | ||
|
|
||
| def _is_healthy(self): | ||
| is_healthy = self._is_finite() and self._is_within_z_range() | ||
| return is_healthy | ||
|
|
||
| def is_absorbing(self, obs): | ||
| absorbing = self._terminate_when_unhealthy and not self._is_healthy() | ||
| return absorbing | ||
|
|
||
| def _get_healthy_reward(self, obs): | ||
| return ( | ||
| self._terminate_when_unhealthy and self._is_healthy() | ||
| ) * self._healthy_reward | ||
|
|
||
| def _get_forward_reward(self): | ||
| forward_reward = self._read_data("torso_vel")[3] | ||
| return self._forward_reward_weight * forward_reward | ||
|
|
||
| def _get_ctrl_cost(self, action): | ||
| ctrl_cost = np.sum(np.square(action)) | ||
| return self._ctrl_cost_weight * ctrl_cost | ||
|
|
||
| def _get_contact_cost(self, obs): | ||
| collision_force = self.obs_helper.get_from_obs(obs, "collision_force") | ||
| contact_cost = np.sum( | ||
| np.square(np.clip(collision_force, *self._contact_force_range)) | ||
| ) | ||
| return self._contact_cost_weight * contact_cost | ||
|
|
||
| def reward(self, obs, action, next_obs, absorbing): | ||
| healthy_reward = self._get_healthy_reward(next_obs) | ||
| forward_reward = self._get_forward_reward() | ||
| cost = self._get_ctrl_cost(action) | ||
| if self._use_contact_forces: | ||
| contact_cost = self._get_contact_cost(next_obs) | ||
| cost += contact_cost | ||
| reward = healthy_reward + forward_reward - cost | ||
| return reward | ||
|
|
||
| def _generate_noise(self): | ||
| self._data.qpos[:] = self._data.qpos + np.random.uniform( | ||
| -self._reset_noise_scale, self._reset_noise_scale, size=self._model.nq | ||
| ) | ||
|
|
||
| self._data.qvel[:] = ( | ||
| self._data.qvel | ||
| + self._reset_noise_scale * np.random.standard_normal(self._model.nv) | ||
| ) | ||
|
|
||
| def setup(self, obs): | ||
| super().setup(obs) | ||
|
|
||
| self._generate_noise() | ||
|
|
||
| mujoco.mj_forward(self._model, self._data) # type: ignore | ||
|
|
||
| def _create_info_dictionary(self, obs, action): | ||
| info = { | ||
| "healthy_reward": self._get_healthy_reward(obs), | ||
| "forward_reward": self._get_forward_reward(), | ||
| } | ||
| info["ctrl_cost"] = self._get_ctrl_cost(action) | ||
| if self._use_contact_forces: | ||
| info["contact_cost"] = self._get_contact_cost(obs) | ||
| return info | ||
|
|
||
| def get_states(self): | ||
| """Return the position and velocity joint states of the model""" | ||
| return np.concatenate([self._data.qpos.flat, self._data.qvel.flat]) |
Empty file.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.