Skip to content

Commit 968f350

Browse files
update
1 parent 8335744 commit 968f350

File tree

4 files changed

+99
-56
lines changed

4 files changed

+99
-56
lines changed

mujoco_robot_environments/lasa_data_generation.py

Lines changed: 88 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -18,75 +18,112 @@
1818
initialize(version_base=None, config_path="./config", job_name="rearrangement")
1919

2020
# add task configs
21-
COLOR_SEPARATING_CONFIG = compose(
21+
CONFIG = compose(
2222
config_name="rearrangement",
2323
overrides=[
2424
"arena/props=colour_splitter",
2525
"simulation_tuning_mode=True"
2626
]
2727
)
2828

29-
# instantiate color separation task
30-
env = LasaDrawEnv(viewer=True, cfg=COLOR_SEPARATING_CONFIG)
31-
3229
# Leverage demonstrations from LASA dataset
3330
import pyLasaDataset as lasa
34-
s_data = lasa.DataSet.Sshape
35-
demos = s_data.demos
31+
shapes = ["CShape", "GShape", "JShape", "LShape", "NShape", "PShape", "RShape", "Sshape", "WShape", "Zshape"]
32+
33+
def transform_to_workspace(trajectories, workspace_bounds):
34+
"""
35+
Transform trajectories to fit within a workspace while maintaining aspect ratio.
36+
"""
37+
# Extract workspace
38+
workspace_x_min, workspace_x_max, workspace_y_min, workspace_y_max = workspace_bounds
3639

37-
def preprocess_demo(demo_data):
38-
pos = demo_data.pos
39-
vel = demo_data.vel
40+
# Extract range of demonstrations
41+
stack = np.hstack([d.pos for d in demos])
42+
original_x_min, original_y_min = np.min(stack, axis=1)
43+
original_x_max, original_y_max = np.max(stack, axis=1)
44+
45+
# Compute original width and height
46+
original_width = original_x_max - original_x_min
47+
original_height = original_y_max - original_y_min
48+
49+
# Compute workspace width and height
50+
workspace_width = workspace_x_max - workspace_x_min
51+
workspace_height = workspace_y_max - workspace_y_min
52+
53+
# Compute scaling factors
54+
scale_x = workspace_width / original_width
55+
scale_y = workspace_height / original_height
56+
57+
# Use the smaller scaling factor to maintain aspect ratio
58+
scale = min(scale_x, scale_y)
59+
60+
# Compute offsets to center trajectories in the workspace
61+
offset_x = workspace_x_min + (workspace_width - (original_width * scale)) / 2
62+
offset_y = workspace_y_min + (workspace_height - (original_height * scale)) / 2
63+
64+
# Transform each trajectory
65+
transformed_position_trajectories = []
66+
transformed_velocity_trajectories = []
67+
for traj in trajectories:
68+
pos = traj.pos
69+
vel = traj.vel
70+
71+
# Scale and shift
72+
transformed_x = (pos[0, :] - original_x_min) * scale + offset_x
73+
transformed_y = (pos[1, :] - original_y_min) * scale + offset_y
74+
z = np.repeat(0.55, pos.shape[1])
75+
transformed_position_trajectories.append(np.vstack((transformed_x, transformed_y, z)))
4076

41-
# scale position data
42-
pos_scaled = (pos / 200) + 0.2
43-
pos_scaled = pos_scaled[:,::4]
44-
positions = np.vstack([pos_scaled[1,:] + 0.2, -pos_scaled[0,:] + 0.2, np.repeat(0.55, pos_scaled.shape[1])]).T
77+
# scale
78+
vel_z = np.repeat(0.0, pos.shape[1])
79+
transformed_velocity_trajectories.append(np.vstack([vel * scale / 800, vel_z])) # TODO: formalize velocity scaling
80+
81+
return np.array(transformed_position_trajectories), np.array(transformed_velocity_trajectories)
4582

46-
# scale velocity data
47-
vel_scaled = (vel / 800)
48-
vel_scaled = vel_scaled[:,::4]
49-
velocities = np.vstack([vel_scaled[1,:], vel_scaled[0,:], np.repeat(0.0, pos_scaled.shape[1])]).T
83+
env = LasaDrawEnv(cfg=CONFIG)
84+
_, _, _, obs = env.reset()
5085

51-
return positions, velocities
52-
5386
# interactive control of robot with mocap body
54-
_, _, _, obs = env.reset()
5587
data = {}
56-
for demo_idx, demo in enumerate(demos):
57-
positions, velocities = preprocess_demo(demo)
58-
joint_positions, joint_velocities, joint_torques = [], [], []
59-
data[f"trajectory_{demo_idx}"] = {}
88+
for char in shapes:
89+
demos = lasa.DataSet.__getattr__(char).demos
90+
pos, vel = transform_to_workspace(demos, workspace_bounds=[0.3, 0.6, -0.3, 0.3])
91+
for demo_idx, (positions, velocities) in enumerate(zip(pos, vel)):
92+
print(f"Processing {char} demo {demo_idx}...")
93+
joint_positions, joint_velocities, joint_torques = [], [], []
94+
data[f"{char}_trajectory_{demo_idx}"] = {}
95+
96+
for idx, (target_pos, target_vel) in enumerate(zip(positions.T, velocities.T)):
97+
while True:
98+
pos, vel, torque = env.move_to_draw_target(target_pos, target_vel)
6099

61-
for idx, (target_pos, target_vel) in enumerate(zip(positions, velocities)):
62-
while True:
63-
pos, vel, torque = env.move_to_draw_target(target_pos, target_vel)
100+
if idx != 0:
101+
joint_positions.append(pos)
102+
joint_velocities.append(vel)
103+
joint_torques.append(torque)
64104

65-
if idx != 0:
66-
joint_positions.append(pos)
67-
joint_velocities.append(vel)
68-
joint_torques.append(torque)
105+
# check if target is reached
106+
if env._robot.arm_controller.current_position_error() < 5e-3:
107+
break
108+
109+
if CONFIG.viewer:
110+
with env.passive_view.lock():
111+
env.passive_view.user_scn.ngeom += 1
112+
mujoco.mjv_initGeom(
113+
env.passive_view.user_scn.geoms[env.passive_view.user_scn.ngeom - 1],
114+
type=mujoco.mjtGeom.mjGEOM_SPHERE,
115+
size=[0.001, 0, 0],
116+
pos=target_pos,
117+
mat=np.eye(3).flatten(),
118+
rgba=[1, 0, 0, 1]
119+
)
120+
env.passive_view.sync()
69121

70-
# check if target is reached
71-
if env._robot.arm_controller.current_position_error() < 5e-3:
72-
break
73-
74-
with env.passive_view.lock():
75-
env.passive_view.user_scn.ngeom += 1
76-
mujoco.mjv_initGeom(
77-
env.passive_view.user_scn.geoms[env.passive_view.user_scn.ngeom-1],
78-
type=mujoco.mjtGeom.mjGEOM_SPHERE,
79-
size=[0.001, 0, 0],
80-
pos=target_pos,
81-
mat=np.eye(3).flatten(),
82-
rgba=[1, 0, 0, 1]
83-
)
84-
env.passive_view.sync()
85-
86-
data[f"trajectory_{demo_idx}"]["joint_positions"] = np.vstack(joint_positions)
87-
data[f"trajectory_{demo_idx}"]["joint_velocities"] = np.vstack(joint_velocities)
88-
data[f"trajectory_{demo_idx}"]["joint_torques"] = np.vstack(joint_torques)
89-
122+
# env.close()
123+
data[f"{char}_trajectory_{demo_idx}"]["joint_positions"] = np.vstack(joint_positions)
124+
data[f"{char}_trajectory_{demo_idx}"]["joint_velocities"] = np.vstack(joint_velocities)
125+
data[f"{char}_trajectory_{demo_idx}"]["joint_torques"] = np.vstack(joint_torques)
126+
90127
# Save to HDF5
91128
with h5py.File("robot_trajectories.h5", "w") as f:
92129
for traj_name, data in data.items():

mujoco_robot_environments/lasa_policy_deployment.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,15 @@ def h5_to_dict(h5_group):
5050
current_joint_position = starting_joint_position
5151

5252
# load the trained model
53-
with open("flax_apply_method.bin", "rb") as f:
53+
# with open("ours_method.bin", "rb") as f:
54+
# serialized_from_file = f.read()
55+
# model = export.deserialize(serialized_from_file)
56+
# dynamics_state = jnp.zeros((1, 5000))
57+
58+
with open("feedforward_method.bin", "rb") as f:
5459
serialized_from_file = f.read()
5560
model = export.deserialize(serialized_from_file)
56-
dynamics_state = jnp.zeros((1, 5000))
57-
61+
5862
# instantiate the task
5963
env = LasaDrawEnv(viewer=True, cfg=COLOR_SEPARATING_CONFIG)
6064
_, _, _, obs = env.reset(current_joint_position)
@@ -98,7 +102,9 @@ def preprocess_demo(demo_data):
98102

99103
while True:
100104
# make a prediction using the model
101-
position_target, dynamics_state = model.call(jnp.expand_dims(current_joint_position, axis=0), dynamics_state)
105+
# position_target, dynamics_state = model.call(jnp.expand_dims(current_joint_position, axis=0), dynamics_state)
106+
position_target = model.call(jnp.expand_dims(current_joint_position, axis=0))
107+
102108

103109
# pass the target to position actuators
104110
current_joint_position = env.move_to_joint_position_target(position_target[0, :7])

mujoco_robot_environments/tasks/lasa_draw.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ def reset(self, arm_configuration = None) -> dm_env.TimeStep:
207207
# reset arm to home position
208208
# Note: for other envs we may want random sampling of initial arm positions
209209
self.arm.set_joint_angles(self._physics, arm_configuration)
210-
print(self.arm.named_configurations["home"])
211210

212211
# configure viewer
213212
if self.has_viewer:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ mujoco_robot_environments = ['*.png']
3737
[tool.poetry.dependencies]
3838
python = "3.10.6"
3939
numpy = "^1.16.0"
40+
jax = {extras = ["cuda12"], version = "0.4.38"}
4041
mujoco = "3.2.6"
4142
mujoco_controllers = {path="./mujoco_robot_environments/mujoco_controllers", develop=true}
4243
brax = "^0.10.4"

0 commit comments

Comments
 (0)