1+ """Collecting robot demonstrations of LASA drawing dataset."""
2+
3+ import jax
4+ import jax .numpy as jnp
5+ import jax .export as export
6+ import numpy as np
7+ import mujoco
8+
9+ import pyLasaDataset as lasa
10+ import h5py
11+
12+ from mujoco_robot_environments .tasks .lasa_draw import LasaDrawEnv
13+
14+ import hydra
15+ from hydra import compose , initialize
16+
17+
18+ if __name__ == "__main__" :
19+ # clear hydra global state to avoid conflicts with other hydra instances
20+ hydra .core .global_hydra .GlobalHydra .instance ().clear ()
21+ initialize (version_base = None , config_path = "./config" , job_name = "rearrangement" )
22+
23+ # add task configs
24+ COLOR_SEPARATING_CONFIG = compose (
25+ config_name = "rearrangement" ,
26+ overrides = [
27+ "arena/props=colour_splitter" ,
28+ "simulation_tuning_mode=False" ,
29+ "robots/arm/actuator_config=position" ,
30+ ]
31+ )
32+
33+ # load demo data to evaluate against
34+ def h5_to_dict (h5_group ):
35+ """
36+ Recursively convert an HDF5 group or dataset into a nested dictionary.
37+ """
38+ result = {}
39+ for key , item in h5_group .items ():
40+ if isinstance (item , h5py .Group ): # If it's a group, recurse
41+ result [key ] = h5_to_dict (item )
42+ elif isinstance (item , h5py .Dataset ): # If it's a dataset, load it
43+ result [key ] = item [:]
44+ return result
45+
46+ with h5py .File ("./robot_trajectories.h5" , "r" ) as f :
47+ data = h5_to_dict (f )
48+
49+ starting_joint_position = data ["trajectory_0" ]["position" ][0 ]
50+ current_joint_position = starting_joint_position
51+
52+ # load the trained model
53+ with open ("flax_apply_method.bin" , "rb" ) as f :
54+ serialized_from_file = f .read ()
55+ model = export .deserialize (serialized_from_file )
56+ dynamics_state = jnp .zeros ((1 , 5000 ))
57+
58+ # instantiate the task
59+ env = LasaDrawEnv (viewer = True , cfg = COLOR_SEPARATING_CONFIG )
60+ _ , _ , _ , obs = env .reset (current_joint_position )
61+
62+ # draw the demonstration trajectory
63+ # Leverage demonstrations from LASA dataset
64+ import pyLasaDataset as lasa
65+ s_data = lasa .DataSet .Sshape
66+ demos = s_data .demos
67+
68+ def preprocess_demo (demo_data ):
69+ pos = demo_data .pos
70+ vel = demo_data .vel
71+
72+ # scale position data
73+ pos_scaled = (pos / 200 ) + 0.2
74+ pos_scaled = pos_scaled [:,::4 ]
75+ positions = np .vstack ([pos_scaled [1 ,:] + 0.2 , - pos_scaled [0 ,:] + 0.2 , np .repeat (0.55 , pos_scaled .shape [1 ])]).T
76+
77+ # scale velocity data
78+ vel_scaled = (vel / 800 )
79+ vel_scaled = vel_scaled [:,::4 ]
80+ velocities = np .vstack ([vel_scaled [1 ,:], vel_scaled [0 ,:], np .repeat (0.0 , pos_scaled .shape [1 ])]).T
81+
82+ return positions , velocities
83+
84+ positions , velocities = preprocess_demo (demos [0 ])
85+
86+ for target_position in positions :
87+ with env .passive_view .lock ():
88+ env .passive_view .user_scn .ngeom += 1
89+ mujoco .mjv_initGeom (
90+ env .passive_view .user_scn .geoms [env .passive_view .user_scn .ngeom - 1 ],
91+ type = mujoco .mjtGeom .mjGEOM_SPHERE ,
92+ size = [0.001 , 0 , 0 ],
93+ pos = target_position ,
94+ mat = np .eye (3 ).flatten (),
95+ rgba = [1 , 0 , 0 , 1 ]
96+ )
97+ env .passive_view .sync ()
98+
99+ while True :
100+ # make a prediction using the model
101+ position_target , dynamics_state = model .call (jnp .expand_dims (current_joint_position , axis = 0 ), dynamics_state )
102+
103+ # pass the target to position actuators
104+ current_joint_position = env .move_to_joint_position_target (position_target [0 , :7 ])
105+
106+ # get eef xpos
107+ eef_pos = env ._robot .arm_controller .current_eef_position - np .array ([0.0 , 0.0 , 0.1 ])
108+
109+ # draw the current state
110+ with env .passive_view .lock ():
111+ env .passive_view .user_scn .ngeom += 1
112+ mujoco .mjv_initGeom (
113+ env .passive_view .user_scn .geoms [env .passive_view .user_scn .ngeom - 1 ],
114+ type = mujoco .mjtGeom .mjGEOM_SPHERE ,
115+ size = [0.001 , 0 , 0 ],
116+ pos = eef_pos ,
117+ mat = np .eye (3 ).flatten (),
118+ rgba = [0 , 1 , 0 , 1 ]
119+ )
120+ env .passive_view .sync ()
121+
122+ env .close ()
123+
0 commit comments