Skip to content

Commit 6c10c85

Browse files
authored
Mujoco manipulation (#163)
* Create mujoco_manipulation_ppo.py * Add data for MuJoCo manipulation envs * Add manipulation environments * Delete mushroom_rl/environments/mujoco_envs/data/peg_insertion.xml * Add peg insertion xml * Delete mushroom_rl/environments/mujoco_envs/data/panda/assets/hole.stl * Update viewer.py to include option to render coordinate frames * Add util for quat distance * Update mujoco_manipulation_ppo.py * Remove unused commented lines * Update panda.py
1 parent 445d142 commit 6c10c85

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

82 files changed

+927580
-1
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
import torch.optim as optim
6+
7+
from mushroom_rl.algorithms.actor_critic import PPO
8+
from mushroom_rl.core import Core, Logger
9+
from mushroom_rl.environments import Reach, Push, Pick, PegInsertion
10+
from mushroom_rl.policy import GaussianTorchPolicy
11+
from mushroom_rl.rl_utils.preprocessors import StandardizationPreprocessor
12+
13+
from tqdm import trange
14+
15+
16+
class Network(nn.Module):
    """Two-hidden-layer tanh MLP shared by the PPO actor and critic.

    The input width is taken from the last element of ``input_shape`` and
    the output width from the first element of ``output_shape``. Hidden
    layers are orthogonally initialized with gain ``sqrt(2)``; the output
    layer uses a small gain (0.01) so initial actions / value estimates
    stay near zero, as is common PPO practice. All biases start at zero.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super().__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        self._h1 = nn.Linear(n_input, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, n_output)

        nn.init.orthogonal_(self._h1.weight, gain=np.sqrt(2))
        nn.init.orthogonal_(self._h2.weight, gain=np.sqrt(2))
        # Small gain keeps the initial output close to zero.
        nn.init.orthogonal_(self._h3.weight, gain=0.01)

        nn.init.constant_(self._h1.bias, 0)
        nn.init.constant_(self._h2.bias, 0)
        nn.init.constant_(self._h3.bias, 0)

    def forward(self, state, **kwargs):
        """Map a (possibly (N, 1, D)-shaped) state batch to the output.

        ``squeeze(state, 1)`` drops a singleton step dimension if present;
        the cast to float guards against double-precision inputs.
        """
        # torch.tanh replaces the deprecated F.tanh alias.
        features1 = torch.tanh(self._h1(torch.squeeze(state, 1).float()))
        features2 = torch.tanh(self._h2(features1))

        return self._h3(features2)
41+
42+
43+
def experiment(env, n_epochs, n_steps, n_episodes_test):
    """Train PPO on one MuJoCo manipulation environment.

    Args:
        env: environment class (e.g. ``Reach``); instantiated with no args.
        n_epochs (int): number of training epochs.
        n_steps (int): environment steps collected per epoch.
        n_episodes_test (int): episodes used for each evaluation.

    After training, waits for a key press and renders five episodes.
    """
    np.random.seed()

    logger = Logger(PPO.__name__, results_dir=None)
    logger.strong_line()
    logger.info("Experiment Algorithm: " + PPO.__name__)

    mdp = env()

    # Hyper-parameters reused in more than one place.
    n_features = 64
    batch_size = 64
    std_0 = 1.0
    n_steps_per_fit = 2000

    critic_params = dict(
        network=Network,
        optimizer={"class": optim.Adam, "params": {"lr": 3e-4}},
        loss=F.mse_loss,
        n_features=n_features,
        batch_size=batch_size,
        input_shape=mdp.info.observation_space.shape,
        output_shape=(1,),
    )

    alg_params = dict(
        actor_optimizer={"class": optim.Adam, "params": {"lr": 3e-4}},
        n_epochs_policy=4,
        batch_size=batch_size,
        eps_ppo=0.2,
        lam=0.95,
        critic_params=critic_params,
    )

    policy = GaussianTorchPolicy(
        Network,
        mdp.info.observation_space.shape,
        mdp.info.action_space.shape,
        std_0=std_0,
        n_features=n_features,
    )

    agent = PPO(mdp.info, policy, **alg_params)

    # Standardize observations before they reach the agent.
    agent.add_core_preprocessor(
        StandardizationPreprocessor(mdp.info, backend="numpy")
    )

    core = Core(agent, mdp)

    def _evaluate_and_log(epoch):
        # Evaluate the current policy and log discounted / undiscounted
        # returns plus the policy entropy for this epoch.
        dataset = core.evaluate(n_episodes=n_episodes_test, render=False)
        logger.epoch_info(
            epoch,
            J=np.mean(dataset.discounted_return),
            R=np.mean(dataset.undiscounted_return),
            entropy=agent.policy.entropy().item(),
        )

    _evaluate_and_log(0)

    for it in trange(n_epochs, leave=False):
        core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
        _evaluate_and_log(it + 1)

    logger.info("Press a button to visualize")
    input()
    core.evaluate(n_episodes=5, render=True)
120+
121+
122+
if __name__ == "__main__":
    # Run the same PPO experiment on every manipulation task in turn.
    for env in (Reach, Push, Pick, PegInsertion):
        experiment(env, n_epochs=50, n_steps=100_000, n_episodes_test=10)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<mujoco model="table">
<!-- NOTE(review): the model is named "table" but defines only a free cube;
     confirm the name is intentional. -->
<worldbody>
<body name="cube" pos="0.5 0.0 0.024">
<!-- Presumably size gives half-extents (MuJoCo box convention), i.e. a
     48 mm cube resting at z = 0.024 so it sits on the z=0 plane — verify. -->
<geom name="cube" type="box" size="0.024 0.024 0.024" mass="0.216" friction="1 0.03 0.003" solref="0.01 1"/>
<!-- Free joint lets the cube translate and rotate freely (pushable/pickable). -->
<freejoint name="cube"/>
</body>
</worldbody>
</mujoco>

0 commit comments

Comments
 (0)