|
302 | 302 | " wandb.log({f\"{mode}/info_table\": table}, commit=commit)\n", |
303 | 303 | " \n", |
304 | 304 | "def log_figure_from_history(info: list,\n", |
305 | | - " episode: int,\n", |
306 | | - " tracking: Literal[\"wandb\"], # only wandb implemented so far\n", |
307 | | - " mode: Literal[\"train\", \"val\", \"test\"],\n", |
| 305 | + " episode: int,\n", |
| 306 | + " tracking: Literal[\"wandb\"], # only wandb implemented so far\n", |
| 307 | + " mode: Literal[\"train\", \"val\", \"test\"],\n", |
| 308 | + " env: BaseEnvironment,\n", |
308 | 309 | " commit: bool = True\n", |
309 | | - " ):\n", |
| 310 | + " ):\n", |
310 | 311 | " if tracking == \"wandb\":\n", |
311 | 312 | " # Plot reward and true reward over time\n", |
312 | 313 | " plt.figure(figsize=(10, 6))\n", |
|
323 | 324 | " # Plot action over time\n", |
324 | 325 | " plt.figure(figsize=(10, 6))\n", |
325 | 326 | " sns.lineplot(x=list(range(len(info))), y=[row[\"action\"] for row in info], label=\"Action\")\n", |
| 327 | + " plt.ylim(env.action_space.low[0], env.action_space.high[0]) # Set y-limits based on action space\n", |
326 | 328 | " plt.title(\"Action over time\")\n", |
327 | 329 | " plt.xlabel(\"T\")\n", |
328 | 330 | " plt.ylabel(\"Action\")\n", |
|
438 | 440 | " mode = env.mode\n", |
439 | 441 | " wandb.log({f\"{mode}/Episode\":episode,f\"{mode}/R\": R, f\"{mode}/J\": J}, commit=False)\n", |
440 | 442 | " log_info_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, commit=False)\n", |
441 | | - " log_figure_from_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, commit=True)\n", |
| 443 | + " log_figure_from_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, env, commit=True)\n", |
442 | 444 | " if return_dataset:\n", |
443 | 445 | " return np.mean(list_R), np.mean(list_J), dataset\n", |
444 | 446 | " else:\n", |
|
592 | 594 | " J_list.append(J)\n", |
593 | 595 | " wandb.log({f\"test/R\": R, f\"test/J\": J}, commit=False)\n", |
594 | 596 | " log_info_history(env.get_info(), episode, tracking, \"test\", commit=False)\n", |
595 | | - " log_figure_from_history(env.get_info(), episode, tracking, \"test\", commit=True)\n", |
| 597 | + " log_figure_from_history(env.get_info(), episode, tracking, \"test\", env, commit=True)\n", |
596 | 598 | " if ((episode+1) % print_freq) == 0:\n", |
597 | 599 | " logging.info(f\"Episode {episode+1}: R={R}, J={J}\")\n", |
598 | 600 | " elif agent.train_mode == \"pretrained\":\n", |
|
651 | 653 | " sys.stdout.flush()\n", |
652 | 654 | " wandb.log({f\"test/R\": R, f\"test/J\": J}, commit=False)\n", |
653 | 655 | " log_info_history([ep[1]for ep in episode_dataset], episode, tracking, \"test\", commit=False)\n", |
654 | | - " log_figure_from_history([ep[1]for ep in episode_dataset], episode, tracking, \"test\", commit=True)\n", |
| 656 | + " log_figure_from_history([ep[1]for ep in episode_dataset], episode, tracking, \"test\", env, commit=True)\n", |
655 | 657 | " dataset.append(episode_dataset)\n", |
656 | 658 | " \n", |
657 | 659 | " elif agent.train_mode == \"env_interaction\":\n", |
|
0 commit comments