Skip to content

Commit 8335744

Browse files
update
1 parent 9f08fca commit 8335744

File tree

4 files changed

+151
-8
lines changed

4 files changed

+151
-8
lines changed

mujoco_robot_environments/lasa_data_generation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,17 @@ def preprocess_demo(demo_data):
5555
data = {}
5656
for demo_idx, demo in enumerate(demos):
5757
positions, velocities = preprocess_demo(demo)
58-
joint_positions, joint_velocities = [], []
58+
joint_positions, joint_velocities, joint_torques = [], [], []
5959
data[f"trajectory_{demo_idx}"] = {}
6060

6161
for idx, (target_pos, target_vel) in enumerate(zip(positions, velocities)):
6262
while True:
63-
pos, vel = env.move_to_draw_target(target_pos, target_vel)
63+
pos, vel, torque = env.move_to_draw_target(target_pos, target_vel)
6464

6565
if idx != 0:
6666
joint_positions.append(pos)
6767
joint_velocities.append(vel)
68+
joint_torques.append(torque)
6869

6970
# check if target is reached
7071
if env._robot.arm_controller.current_position_error() < 5e-3:
@@ -81,16 +82,18 @@ def preprocess_demo(demo_data):
8182
rgba=[1, 0, 0, 1]
8283
)
8384
env.passive_view.sync()
84-
85+
8586
data[f"trajectory_{demo_idx}"]["joint_positions"] = np.vstack(joint_positions)
8687
data[f"trajectory_{demo_idx}"]["joint_velocities"] = np.vstack(joint_velocities)
88+
data[f"trajectory_{demo_idx}"]["joint_torques"] = np.vstack(joint_torques)
8789

8890
# Save to HDF5
8991
with h5py.File("robot_trajectories.h5", "w") as f:
9092
for traj_name, data in data.items():
9193
group = f.create_group(traj_name)
9294
group.create_dataset("position", data=data["joint_positions"], compression="gzip")
9395
group.create_dataset("velocity", data=data["joint_velocities"], compression="gzip")
96+
group.create_dataset("torque", data=data["joint_torques"], compression="gzip")
9497

9598
env.close()
9699

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""Collecting robot demonstrations of LASA drawing dataset."""
2+
3+
import jax
4+
import jax.numpy as jnp
5+
import jax.export as export
6+
import numpy as np
7+
import mujoco
8+
9+
import pyLasaDataset as lasa
10+
import h5py
11+
12+
from mujoco_robot_environments.tasks.lasa_draw import LasaDrawEnv
13+
14+
import hydra
15+
from hydra import compose, initialize
16+
17+
18+
if __name__=="__main__":
19+
# clear hydra global state to avoid conflicts with other hydra instances
20+
hydra.core.global_hydra.GlobalHydra.instance().clear()
21+
initialize(version_base=None, config_path="./config", job_name="rearrangement")
22+
23+
# add task configs
24+
COLOR_SEPARATING_CONFIG = compose(
25+
config_name="rearrangement",
26+
overrides=[
27+
"arena/props=colour_splitter",
28+
"simulation_tuning_mode=False",
29+
"robots/arm/actuator_config=position",
30+
]
31+
)
32+
33+
# load demo data to evaluate against
34+
def h5_to_dict(h5_group):
35+
"""
36+
Recursively convert an HDF5 group or dataset into a nested dictionary.
37+
"""
38+
result = {}
39+
for key, item in h5_group.items():
40+
if isinstance(item, h5py.Group): # If it's a group, recurse
41+
result[key] = h5_to_dict(item)
42+
elif isinstance(item, h5py.Dataset): # If it's a dataset, load it
43+
result[key] = item[:]
44+
return result
45+
46+
with h5py.File("./robot_trajectories.h5", "r") as f:
47+
data = h5_to_dict(f)
48+
49+
starting_joint_position = data["trajectory_0"]["position"][0]
50+
current_joint_position = starting_joint_position
51+
52+
# load the trained model
53+
with open("flax_apply_method.bin", "rb") as f:
54+
serialized_from_file = f.read()
55+
model = export.deserialize(serialized_from_file)
56+
dynamics_state = jnp.zeros((1, 5000))
57+
58+
# instantiate the task
59+
env = LasaDrawEnv(viewer=True, cfg=COLOR_SEPARATING_CONFIG)
60+
_, _, _, obs = env.reset(current_joint_position)
61+
62+
# draw the demonstration trajectory
63+
# Leverage demonstrations from LASA dataset
64+
import pyLasaDataset as lasa
65+
s_data = lasa.DataSet.Sshape
66+
demos = s_data.demos
67+
68+
def preprocess_demo(demo_data):
69+
pos = demo_data.pos
70+
vel = demo_data.vel
71+
72+
# scale position data
73+
pos_scaled = (pos / 200) + 0.2
74+
pos_scaled = pos_scaled[:,::4]
75+
positions = np.vstack([pos_scaled[1,:] + 0.2, -pos_scaled[0,:] + 0.2, np.repeat(0.55, pos_scaled.shape[1])]).T
76+
77+
# scale velocity data
78+
vel_scaled = (vel / 800)
79+
vel_scaled = vel_scaled[:,::4]
80+
velocities = np.vstack([vel_scaled[1,:], vel_scaled[0,:], np.repeat(0.0, pos_scaled.shape[1])]).T
81+
82+
return positions, velocities
83+
84+
positions, velocities = preprocess_demo(demos[0])
85+
86+
for target_position in positions:
87+
with env.passive_view.lock():
88+
env.passive_view.user_scn.ngeom += 1
89+
mujoco.mjv_initGeom(
90+
env.passive_view.user_scn.geoms[env.passive_view.user_scn.ngeom-1],
91+
type=mujoco.mjtGeom.mjGEOM_SPHERE,
92+
size=[0.001, 0, 0],
93+
pos=target_position,
94+
mat=np.eye(3).flatten(),
95+
rgba=[1, 0, 0, 1]
96+
)
97+
env.passive_view.sync()
98+
99+
while True:
100+
# make a prediction using the model
101+
position_target, dynamics_state = model.call(jnp.expand_dims(current_joint_position, axis=0), dynamics_state)
102+
103+
# pass the target to position actuators
104+
current_joint_position = env.move_to_joint_position_target(position_target[0, :7])
105+
106+
# get eef xpos
107+
eef_pos = env._robot.arm_controller.current_eef_position - np.array([0.0, 0.0, 0.1])
108+
109+
# draw the current state
110+
with env.passive_view.lock():
111+
env.passive_view.user_scn.ngeom += 1
112+
mujoco.mjv_initGeom(
113+
env.passive_view.user_scn.geoms[env.passive_view.user_scn.ngeom-1],
114+
type=mujoco.mjtGeom.mjGEOM_SPHERE,
115+
size=[0.001, 0, 0],
116+
pos=eef_pos,
117+
mat=np.eye(3).flatten(),
118+
rgba=[0, 1, 0, 1]
119+
)
120+
env.passive_view.sync()
121+
122+
env.close()
123+

mujoco_robot_environments/models/robot_arm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def __init__(
3838
self.eef_site = arm.attachment_site
3939
self.arm_joints = arm.joints
4040
self.arm_joint_ids = np.array(physics.bind(self.arm_joints).dofadr)
41-
41+
self.arm_actuators = arm.actuators
42+
4243
# set gripper and controller
4344
if gripper is not None:
4445
self.end_effector = gripper

mujoco_robot_environments/tasks/lasa_draw.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,19 @@ def model(self) -> mujoco.MjModel:
194194
def data(self) -> mujoco.MjData:
195195
return self.physics.data
196196

197-
def reset(self) -> dm_env.TimeStep:
197+
def reset(self, arm_configuration = None) -> dm_env.TimeStep:
198198
"""Resets the environment to an initial state and returns the first
199199
`TimeStep` of the new episode.
200200
"""
201-
# reset the lation instance
201+
if arm_configuration is None:
202+
arm_configuration = self.arm.named_configurations["home"]
203+
204+
# reset the simulation instance
202205
self._physics.reset()
203206

204207
# reset arm to home position
205208
# Note: for other envs we may want random sampling of initial arm positions
206-
self.arm.set_joint_angles(self._physics, self.arm.named_configurations["home"])
209+
self.arm.set_joint_angles(self._physics, arm_configuration)
207210
print(self.arm.named_configurations["home"])
208211

209212
# configure viewer
@@ -330,7 +333,20 @@ def move_to_draw_target(self, target_position, target_velocity):
330333
self.passive_view.sync()
331334

332335
# return joint data for recording
333-
return self._physics.bind(self._robot.arm_joints).qpos, self._physics.bind(self._robot.arm_joints).qvel
336+
return self._physics.bind(self._robot.arm_joints).qpos.copy(), self._physics.bind(self._robot.arm_joints).qvel.copy(), self._physics.bind(self._robot.arm_actuators).ctrl.copy()
337+
338+
def move_to_joint_position_target(self, target_position):
339+
"""
340+
Move to position and velocity target for drawing task.
341+
"""
342+
# step the simulation
343+
for _ in range(5):
344+
self._physics.set_control(target_position)
345+
self._physics.step()
346+
if self.passive_view is not None:
347+
self.passive_view.sync()
348+
349+
return self._physics.bind(self._robot.arm_joints).qpos.copy()
334350

335351

336352
if __name__=="__main__":

0 commit comments

Comments
 (0)