Skip to content

Commit ed8f46a

Browse files
committed
⚡ update add contact support for gmr pipline
1 parent e6aaa55 commit ed8f46a

File tree

8 files changed

+280
-159
lines changed

8 files changed

+280
-159
lines changed

docs/guide/quick-start.md

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ This will:
2121
3. Display real-time visualization in MuJoCo viewer
2222
4. Save optimized trajectory and video
2323

24-
## Full Workflow Example
24+
## Full Workflow Example For Dexterous Hand
2525

2626
To process a task from scratch, follow these steps:
2727

@@ -40,7 +40,7 @@ export DATASET_NAME=gigahand
4040
Convert raw human motion data to standardized format:
4141

4242
```bash
43-
uv run spider/process_datasets/oakink.py \
43+
uv run spider/process_datasets/${DATASET_NAME}.py \
4444
--task=${TASK} \
4545
--embodiment-type=${HAND_TYPE} \
4646
--data-id=${DATA_ID}
@@ -103,7 +103,52 @@ Optimize trajectory with physics constraints:
103103

104104
```bash
105105
uv run examples/run_mjwp.py \
106-
+override=${DATASET_NAME} \
106+
task=${TASK} \
107+
dataset_name=${DATASET_NAME} \
108+
data_id=${DATA_ID} \
109+
robot_type=${ROBOT_TYPE} \
110+
embodiment_type=${HAND_TYPE}
111+
```
112+
113+
## Full Workflow Example For Humanoid Robot
114+
115+
Set environment variables:
116+
117+
```bash
118+
export TASK=dance
119+
export HAND_TYPE=humanoid
120+
export DATA_ID=0
121+
export ROBOT_TYPE=unitree_g1
122+
export DATASET_NAME=lafan
123+
```
124+
125+
### Run IK
126+
127+
```bash
128+
# with GMR (remember to generate GMR data trajectoru_gmr.pkl first with their official code. )
129+
uv run spider/process_datasets/gmr.py \
130+
--task=${TASK} \
131+
--dataset-name=${DATASET_NAME} \
132+
--data-id=${DATA_ID} \
133+
--robot-type=${ROBOT_TYPE} \
134+
--embodiment-type=${HAND_TYPE} \
135+
--contact-detection-mode=one
136+
137+
# with locomujoco
138+
uv run spider/process_datasets/locomujoco.py \
139+
--task=${TASK} \
140+
--dataset-name=${DATASET_NAME} \
141+
--data-id=${DATA_ID} \
142+
--robot-type=${ROBOT_TYPE} \
143+
--embodiment-type=${HAND_TYPE}
144+
```
145+
146+
### Run Physics-Based Retargeting
147+
148+
```bash
149+
uv run examples/run_mjwp.py \
150+
+override=humanoid \
151+
dataset_name=${DATASET_NAME} \
107152
task=${TASK} \
108153
data_id=${DATA_ID} \
109154
robot_type=${ROBOT_TYPE} \

examples/config/override/humanoid.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ joint_noise_scale: 0.1
2020
# sampling
2121
improvement_threshold: 0.01
2222
max_num_iterations: 16
23-
temperature: 0.1
23+
temperature: 0.6
2424
num_samples: 2048
2525
# reward
2626
vel_rew_scale: 0.0
2727
joint_rew_scale: 1.0
2828
pos_rew_scale: 3.0
2929
rot_rew_scale: 3.0
30+
contact_rew_scale: 1.0
3031
# note: above parameter is optimized for speed. you can have better performance with horizon=1.0

spider/assets/robots/unitree_g1/scene.xml

Lines changed: 66 additions & 83 deletions
Large diffs are not rendered by default.

spider/assets/robots/unitree_g1/scene_simple.xml renamed to spider/assets/robots/unitree_g1/scene_soft.xml

Lines changed: 83 additions & 62 deletions
Large diffs are not rendered by default.

spider/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class Config:
9797
rot_rew_scale: float = 0.1
9898
vel_rew_scale: float = 0.0001
9999
terminal_rew_scale: float = 1.0
100+
contact_rew_scale: float = 0.0
100101

101102
# === VISUALIZATION CONFIGURATION ===
102103
show_viewer: bool = True
@@ -261,4 +262,16 @@ def process_config(config: Config):
261262
config.ref_dt = task_info["ref_dt"]
262263
loguru.logger.info(f"overriding ref_dt: {config.ref_dt} from task_info.json")
263264

265+
# override contact site ids
266+
if config.contact_rew_scale > 0.0:
267+
if "contact_site_ids" in task_info:
268+
config.contact_site_ids = task_info["contact_site_ids"]
269+
loguru.logger.info(
270+
f"overriding contact_site_ids: {config.contact_site_ids} from task_info.json"
271+
)
272+
else:
273+
raise ValueError(
274+
"contact_site_ids not found in task_info.json while contact_rew_scale > 0.0"
275+
)
276+
264277
return config

spider/io.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ def load_data(
3333
raw_data = np.load(data_path)
3434
qpos_ref = raw_data["qpos"]
3535
qvel_ref = raw_data["qvel"]
36+
# if contact_rew_scale > 0.0, we need to make sure both contact and contact_pos are provided.
37+
if config.contact_rew_scale > 0.0:
38+
if "contact" not in raw_data:
39+
raise ValueError("contact data not found while contact_rew_scale > 0.0")
40+
if "contact_pos" not in raw_data:
41+
raise ValueError("contact_pos data not found while contact_rew_scale > 0.0")
3642
try:
3743
contact = raw_data["contact"]
3844
except:

spider/process_datasets/gmr.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,28 @@ def main(
5757
enable_rate_limiter: bool = False,
5858
start_frame: int = 0,
5959
end_frame: int = -1,
60+
contact_detection_mode: str = "one",
6061
):
62+
"""
63+
Process GMR data to create a SPIDER dataset.
64+
Args:
65+
dataset_dir: The directory containing the GMR data.
66+
dataset_name: The name of the dataset.
67+
robot_type: The type of robot.
68+
embodiment_type: The type of embodiment.
69+
task: The task to perform.
70+
data_id: The id of the data.
71+
show_viewer: Whether to show the viewer.
72+
save_video: Whether to save the video.
73+
overwrite: Whether to overwrite the existing data.
74+
enable_rate_limiter: Whether to enable the rate limiter.
75+
start_frame: The start frame of the data.
76+
end_frame: The end frame of the data.
77+
contact_detection_mode: The mode of contact detection.
78+
"auto": Automatically detect contact based on mujoco contact detection.
79+
"zero": Always disable contact.
80+
"one": Always enable contact.
81+
"""
6182
dataset_dir = os.path.abspath(dataset_dir)
6283
processed_dir = get_processed_data_dir(
6384
dataset_dir=dataset_dir,
@@ -100,17 +121,26 @@ def main(
100121
shutil.copy(src_scene_file, tgt_scene_file)
101122
print(f"copy from {src_scene_file} to {tgt_scene_file}")
102123

103-
# create task info file
104-
task_info_file = f"{scene_dir}/task_info.json"
105-
with open(task_info_file, "w") as f:
106-
json.dump({"ref_dt": 1.0 / fps}, f, indent=2)
107-
print(f"Saved task info to {task_info_file}")
108-
109124
# run mujoco
110125
mj_model = mujoco.MjModel.from_xml_path(tgt_scene_file)
111126
mj_data = mujoco.MjData(mj_model)
112127
run_viewer = get_viewer(show_viewer, mj_model, mj_data)
113128
rate_limiter = RateLimiter(fps)
129+
130+
# contact site id
131+
contact_site_ids = []
132+
for i in range(mj_model.nsite):
133+
site_name = mujoco.mj_id2name(mj_model, mujoco.mjtObj.mjOBJ_SITE, i)
134+
if site_name and "contact" in site_name:
135+
contact_site_ids.append(i)
136+
assert len(contact_site_ids) > 0 and contact_detection_mode != "zero", "No contact site found while you enable contact detection"
137+
138+
# create task info file
139+
task_info_file = f"{scene_dir}/task_info.json"
140+
with open(task_info_file, "w") as f:
141+
json.dump({"ref_dt": 1.0 / fps, "contact_site_ids": contact_site_ids}, f, indent=2)
142+
print(f"Saved task info to {task_info_file}")
143+
114144
# log info
115145
info_list = []
116146
# log video
@@ -129,17 +159,27 @@ def main(
129159
)
130160
else:
131161
mj_data.qvel[:] = 0.0
132-
# compute contact (currently it is a placeholder, will be implemented later)
133-
contact = np.zeros(1)
162+
134163
# compute ctrl
135164
mj_data.ctrl[:] = qpos[i][7:]
136165
mujoco.mj_forward(mj_model, mj_data)
166+
167+
# compute contact
168+
contact_pos = mj_data.site_xpos[contact_site_ids, :]
169+
if contact_detection_mode == "one":
170+
contact = np.ones(len(contact_site_ids))
171+
elif contact_detection_mode == "zero":
172+
contact = np.zeros(len(contact_site_ids))
173+
else:
174+
contact = contact_pos[:, 2] < 0.001
175+
137176
# log
138177
info = {
139178
"qpos": mj_data.qpos.copy(),
140179
"qvel": mj_data.qvel.copy(),
141180
"ctrl": mj_data.ctrl.copy(),
142181
"contact": contact,
182+
"contact_pos": contact_pos.copy(),
143183
}
144184
info_list.append(info)
145185
# render

spider/simulators/mjwp.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,10 @@ def get_reward(
270270
271271
TODO: move reward computation to task-specific module
272272
"""
273-
qpos_ref, qvel_ref, ctrl_ref, contact_ref, _contact_pos_ref = ref
273+
qpos_ref, qvel_ref, ctrl_ref, contact_ref, contact_pos_ref = ref
274274
qpos_sim = wp.to_torch(env.data_wp.qpos)
275275
qvel_sim = wp.to_torch(env.data_wp.qvel)
276+
276277
# weighted qpos tracking
277278
qpos_diff = _diff_qpos(
278279
config, qpos_sim, qpos_ref.unsqueeze(0).repeat(qpos_sim.shape[0], 1)
@@ -284,7 +285,18 @@ def get_reward(
284285

285286
qpos_rew = -qpos_dist * 1.0
286287
qvel_rew = -config.vel_rew_scale * qvel_dist * 1.0
287-
reward = qpos_rew + qvel_rew
288+
289+
# contact reward
290+
if config.contact_rew_scale > 0.0 and len(config.contact_site_ids) > 0:
291+
site_xpos_torch = wp.to_torch(env.data_wp.site_xpos)
292+
contact_pos = site_xpos_torch[:, config.contact_site_ids]
293+
contact_dist = torch.norm(contact_pos - contact_pos_ref, p=2, dim=-1)
294+
contact_dist_masked = contact_dist * contact_ref.unsqueeze(0)
295+
contact_rew = -contact_dist_masked.sum(dim=1)
296+
else:
297+
contact_rew = 0.0
298+
299+
reward = qpos_rew + qvel_rew + contact_rew
288300

289301
info = {
290302
"qpos_dist": qpos_dist,

0 commit comments

Comments
 (0)