Skip to content

Commit 47f0cdd

Browse files
update
1 parent 968f350 commit 47f0cdd

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

mujoco_robot_environments/lasa_data_generation.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,13 @@ def transform_to_workspace(trajectories, workspace_bounds):
8686
# interactive control of robot with mocap body
8787
data = {}
8888
for char in shapes:
89+
data[f"{char}"] = {}
8990
demos = lasa.DataSet.__getattr__(char).demos
9091
pos, vel = transform_to_workspace(demos, workspace_bounds=[0.3, 0.6, -0.3, 0.3])
9192
for demo_idx, (positions, velocities) in enumerate(zip(pos, vel)):
93+
data[f"{char}"][f"trajectory_{demo_idx}"] = {}
9294
print(f"Processing {char} demo {demo_idx}...")
9395
joint_positions, joint_velocities, joint_torques = [], [], []
94-
data[f"{char}_trajectory_{demo_idx}"] = {}
95-
9696
for idx, (target_pos, target_vel) in enumerate(zip(positions.T, velocities.T)):
9797
while True:
9898
pos, vel, torque = env.move_to_draw_target(target_pos, target_vel)
@@ -120,17 +120,19 @@ def transform_to_workspace(trajectories, workspace_bounds):
120120
env.passive_view.sync()
121121

122122
# env.close()
123-
data[f"{char}_trajectory_{demo_idx}"]["joint_positions"] = np.vstack(joint_positions)
124-
data[f"{char}_trajectory_{demo_idx}"]["joint_velocities"] = np.vstack(joint_velocities)
125-
data[f"{char}_trajectory_{demo_idx}"]["joint_torques"] = np.vstack(joint_torques)
126-
127-
# Save to HDF5
128-
with h5py.File("robot_trajectories.h5", "w") as f:
129-
for traj_name, data in data.items():
130-
group = f.create_group(traj_name)
131-
group.create_dataset("position", data=data["joint_positions"], compression="gzip")
132-
group.create_dataset("velocity", data=data["joint_velocities"], compression="gzip")
133-
group.create_dataset("torque", data=data["joint_torques"], compression="gzip")
123+
data[f"{char}"][f"trajectory_{demo_idx}"]["joint_positions"] = np.vstack(joint_positions)
124+
data[f"{char}"][f"trajectory_{demo_idx}"]["joint_velocities"] = np.vstack(joint_velocities)
125+
data[f"{char}"][f"trajectory_{demo_idx}"]["joint_torques"] = np.vstack(joint_torques)
134126

127+
128+
with h5py.File("robot_trajectories.h5", "w") as f:
129+
for character, trajectories in data.items():
130+
char_group = f.create_group(character) # Create a group for each character
131+
for traj_id, traj_data in trajectories.items():
132+
traj_group = char_group.create_group(traj_id) # Create a group for each trajectory ID
133+
traj_group.create_dataset("position", data=traj_data["joint_positions"], compression="gzip")
134+
traj_group.create_dataset("velocity", data=traj_data["joint_velocities"], compression="gzip")
135+
traj_group.create_dataset("torque", data=traj_data["joint_torques"], compression="gzip")
136+
135137
env.close()
136138

0 commit comments

Comments
 (0)