Skip to content

Commit 9811486

Browse files
Merge pull request #159 from s1lent4gnt:lilkm/fix-humanoid_joystick-nb
PiperOrigin-RevId: 794777571 Change-Id: Ia069206a21e563111c33ae2198e6b41e0a942ce4
2 parents 46a951d + 81c1a36 commit 9811486

File tree

1 file changed

+8
-18
lines changed

1 file changed

+8
-18
lines changed

mujoco_playground/experimental/learning/humanoid_joystick.ipynb

Lines changed: 8 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -172,8 +172,6 @@
172172
"ep_length_mean, ep_length_std = [], []\n",
173173
"times = [datetime.now()]\n",
174174
"\n",
175-
"fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
176-
"\n",
177175
"\n",
178176
"def progress(num_steps, metrics):\n",
179177
" # Log to wandb.\n",
@@ -186,20 +184,12 @@
186184
" x_data.append(num_steps)\n",
187185
" y_data.append(metrics[\"eval/episode_reward\"])\n",
188186
" y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
189-
" ep_length_mean.append(metrics[\"eval/avg_episode_length\"])\n",
190-
" ep_length_std.append(metrics[\"eval/avg_episode_length_std\"])\n",
191-
"\n",
192-
" axes[0].set_xlim([0, ppo_params.num_timesteps * 1.25])\n",
193-
" axes[0].set_xlabel(\"# environment steps\")\n",
194-
" axes[0].set_ylabel(\"reward per episode\")\n",
195-
" axes[0].set_title(f\"y={y_data[-1]:.3f}\")\n",
196-
" axes[0].errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
197187
"\n",
198-
" axes[1].set_xlim([0, ppo_params.num_timesteps * 1.25])\n",
199-
" axes[1].set_xlabel(\"# environment steps\")\n",
200-
" axes[1].set_ylabel(\"episode length\")\n",
201-
" axes[1].set_title(f\"y={ep_length_mean[-1]:.3f}\")\n",
202-
" axes[1].errorbar(x_data, ep_length_mean, yerr=ep_length_std, color=\"blue\")\n",
188+
" plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n",
189+
" plt.xlabel(\"# environment steps\")\n",
190+
" plt.ylabel(\"reward per episode\")\n",
191+
" plt.title(f\"y={y_data[-1]:.3f}\")\n",
192+
" plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
203193
"\n",
204194
" display(plt.gcf())\n",
205195
"\n",
@@ -357,11 +347,11 @@
357347
" rewards.append(\n",
358348
" {k[7:]: v for k, v in state.metrics.items() if k.startswith(\"reward/\")}\n",
359349
" )\n",
360-
" linvel.append(eval_env.get_local_linvel(state.data))\n",
350+
" linvel.append(eval_env.get_local_linvel(state.data, \"pelvis\"))\n",
361351
" angvel.append(eval_env.get_gyro(state.data))\n",
362352
" track.append(\n",
363353
" eval_env._reward_tracking_lin_vel(\n",
364-
" state.info[\"command\"], eval_env.get_local_linvel(state.data)\n",
354+
" state.info[\"command\"], eval_env.get_local_linvel(state.data, \"pelvis\")\n",
365355
" )\n",
366356
" )\n",
367357
"\n",
@@ -381,7 +371,7 @@
381371
" qvels.append(qvel)\n",
382372
" qpos_cost.append(jp.sum(jp.square(state.data.qpos[7:] - eval_env._default_pose)))\n",
383373
"\n",
384-
" xyz = np.array(state.data.xpos[eval_env.mj_model.body(\"torso\").id])\n",
374+
" xyz = np.array(state.data.xpos[eval_env.mj_model.body(\"torso_link\").id])\n",
385375
" xyz += np.array([0, 0.0, 0])\n",
386376
" x_axis = state.data.xmat[eval_env._torso_body_id, 0]\n",
387377
" yaw = -np.arctan2(x_axis[1], x_axis[0])\n",

0 commit comments

Comments (0)