|
172 | 172 | "ep_length_mean, ep_length_std = [], []\n", |
173 | 173 | "times = [datetime.now()]\n", |
174 | 174 | "\n", |
175 | | - "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", |
176 | | - "\n", |
177 | 175 | "\n", |
178 | 176 | "def progress(num_steps, metrics):\n", |
179 | 177 | " # Log to wandb.\n", |
|
186 | 184 | " x_data.append(num_steps)\n", |
187 | 185 | " y_data.append(metrics[\"eval/episode_reward\"])\n", |
188 | 186 | " y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n", |
189 | | - " ep_length_mean.append(metrics[\"eval/avg_episode_length\"])\n", |
190 | | - " ep_length_std.append(metrics[\"eval/avg_episode_length_std\"])\n", |
191 | | - "\n", |
192 | | - " axes[0].set_xlim([0, ppo_params.num_timesteps * 1.25])\n", |
193 | | - " axes[0].set_xlabel(\"# environment steps\")\n", |
194 | | - " axes[0].set_ylabel(\"reward per episode\")\n", |
195 | | - " axes[0].set_title(f\"y={y_data[-1]:.3f}\")\n", |
196 | | - " axes[0].errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n", |
197 | 187 | "\n", |
198 | | - " axes[1].set_xlim([0, ppo_params.num_timesteps * 1.25])\n", |
199 | | - " axes[1].set_xlabel(\"# environment steps\")\n", |
200 | | - " axes[1].set_ylabel(\"episode length\")\n", |
201 | | - " axes[1].set_title(f\"y={ep_length_mean[-1]:.3f}\")\n", |
202 | | - " axes[1].errorbar(x_data, ep_length_mean, yerr=ep_length_std, color=\"blue\")\n", |
| 188 | + " plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n", |
| 189 | + " plt.xlabel(\"# environment steps\")\n", |
| 190 | + " plt.ylabel(\"reward per episode\")\n", |
| 191 | + " plt.title(f\"y={y_data[-1]:.3f}\")\n", |
| 192 | + " plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n", |
203 | 193 | "\n", |
204 | 194 | " display(plt.gcf())\n", |
205 | 195 | "\n", |
|
357 | 347 | " rewards.append(\n", |
358 | 348 | " {k[7:]: v for k, v in state.metrics.items() if k.startswith(\"reward/\")}\n", |
359 | 349 | " )\n", |
360 | | - " linvel.append(eval_env.get_local_linvel(state.data))\n", |
| 350 | + " linvel.append(eval_env.get_local_linvel(state.data, \"pelvis\"))\n", |
361 | 351 | " angvel.append(eval_env.get_gyro(state.data))\n", |
362 | 352 | " track.append(\n", |
363 | 353 | " eval_env._reward_tracking_lin_vel(\n", |
364 | | - " state.info[\"command\"], eval_env.get_local_linvel(state.data)\n", |
| 354 | + " state.info[\"command\"], eval_env.get_local_linvel(state.data, \"pelvis\")\n", |
365 | 355 | " )\n", |
366 | 356 | " )\n", |
367 | 357 | "\n", |
|
381 | 371 | " qvels.append(qvel)\n", |
382 | 372 | " qpos_cost.append(jp.sum(jp.square(state.data.qpos[7:] - eval_env._default_pose)))\n", |
383 | 373 | "\n", |
384 | | - " xyz = np.array(state.data.xpos[eval_env.mj_model.body(\"torso\").id])\n", |
| 374 | + " xyz = np.array(state.data.xpos[eval_env.mj_model.body(\"torso_link\").id])\n", |
385 | 375 | " xyz += np.array([0, 0.0, 0])\n", |
386 | 376 | " x_axis = state.data.xmat[eval_env._torso_body_id, 0]\n", |
387 | 377 | " yaw = -np.arctan2(x_axis[1], x_axis[0])\n", |
|
0 commit comments