Skip to content

Commit adfe25d

Browse files
committed
added bounds to figure plots in wandb
1 parent 750d7a8 commit adfe25d

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

ddopai/experiments/experiment_functions_meta.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,12 @@ def log_info_history(info: list,
134134
wandb.log({f"{mode}/info_table": table}, commit=commit)
135135

136136
def log_figure_from_history(info: list,
137-
episode: int,
138-
tracking: Literal["wandb"], # only wandb implemented so far
139-
mode: Literal["train", "val", "test"],
137+
episode: int,
138+
tracking: Literal["wandb"], # only wandb implemented so far
139+
mode: Literal["train", "val", "test"],
140+
env: BaseEnvironment,
140141
commit: bool = True
141-
):
142+
):
142143
if tracking == "wandb":
143144
# Plot reward and true reward over time
144145
plt.figure(figsize=(10, 6))
@@ -155,6 +156,7 @@ def log_figure_from_history(info: list,
155156
# Plot action over time
156157
plt.figure(figsize=(10, 6))
157158
sns.lineplot(x=list(range(len(info))), y=[row["action"] for row in info], label="Action")
159+
plt.ylim(env.action_space.low[0], env.action_space.high[0]) # Set y-limits based on action space
158160
plt.title("Action over time")
159161
plt.xlabel("T")
160162
plt.ylabel("Action")
@@ -253,7 +255,7 @@ def test_agent(agent: BaseAgent,
253255
mode = env.mode
254256
wandb.log({f"{mode}/Episode":episode,f"{mode}/R": R, f"{mode}/J": J}, commit=False)
255257
log_info_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, commit=False)
256-
log_figure_from_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, commit=True)
258+
log_figure_from_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, env, commit=True)
257259
if return_dataset:
258260
return np.mean(list_R), np.mean(list_J), dataset
259261
else:
@@ -407,7 +409,7 @@ def run_experiment( agent: BaseAgent,
407409
J_list.append(J)
408410
wandb.log({f"test/R": R, f"test/J": J}, commit=False)
409411
log_info_history(env.get_info(), episode, tracking, "test", commit=False)
410-
log_figure_from_history(env.get_info(), episode, tracking, "test", commit=True)
412+
log_figure_from_history(env.get_info(), episode, tracking, "test", env, commit=True)
411413
if ((episode+1) % print_freq) == 0:
412414
logging.info(f"Episode {episode+1}: R={R}, J={J}")
413415
elif agent.train_mode == "pretrained":
@@ -466,7 +468,7 @@ def run_experiment( agent: BaseAgent,
466468
sys.stdout.flush()
467469
wandb.log({f"test/R": R, f"test/J": J}, commit=False)
468470
log_info_history([ep[1]for ep in episode_dataset], episode, tracking, "test", commit=False)
469-
log_figure_from_history([ep[1]for ep in episode_dataset], episode, tracking, "test", commit=True)
471+
log_figure_from_history([ep[1]for ep in episode_dataset], episode, tracking, "test", env, commit=True)
470472
dataset.append(episode_dataset)
471473

472474
elif agent.train_mode == "env_interaction":

nbs/40_experiments/10_experiment_functions_meta.ipynb

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,12 @@
302302
" wandb.log({f\"{mode}/info_table\": table}, commit=commit)\n",
303303
" \n",
304304
"def log_figure_from_history(info: list,\n",
305-
" episode: int,\n",
306-
" tracking: Literal[\"wandb\"], # only wandb implemented so far\n",
307-
" mode: Literal[\"train\", \"val\", \"test\"],\n",
305+
" episode: int,\n",
306+
" tracking: Literal[\"wandb\"], # only wandb implemented so far\n",
307+
" mode: Literal[\"train\", \"val\", \"test\"],\n",
308+
" env: BaseEnvironment,\n",
308309
" commit: bool = True\n",
309-
" ):\n",
310+
" ):\n",
310311
" if tracking == \"wandb\":\n",
311312
" # Plot reward and true reward over time\n",
312313
" plt.figure(figsize=(10, 6))\n",
@@ -323,6 +324,7 @@
323324
" # Plot action over time\n",
324325
" plt.figure(figsize=(10, 6))\n",
325326
" sns.lineplot(x=list(range(len(info))), y=[row[\"action\"] for row in info], label=\"Action\")\n",
327+
" plt.ylim(env.action_space.low[0], env.action_space.high[0]) # Set y-limits based on action space\n",
326328
" plt.title(\"Action over time\")\n",
327329
" plt.xlabel(\"T\")\n",
328330
" plt.ylabel(\"Action\")\n",
@@ -438,7 +440,7 @@
438440
" mode = env.mode\n",
439441
" wandb.log({f\"{mode}/Episode\":episode,f\"{mode}/R\": R, f\"{mode}/J\": J}, commit=False)\n",
440442
" log_info_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, commit=False)\n",
441-
" log_figure_from_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, commit=True)\n",
443+
" log_figure_from_history([ep_d[1] for ep_d in episode_dataset], episode, tracking, mode, env, commit=True)\n",
442444
" if return_dataset:\n",
443445
" return np.mean(list_R), np.mean(list_J), dataset\n",
444446
" else:\n",
@@ -592,7 +594,7 @@
592594
" J_list.append(J)\n",
593595
" wandb.log({f\"test/R\": R, f\"test/J\": J}, commit=False)\n",
594596
" log_info_history(env.get_info(), episode, tracking, \"test\", commit=False)\n",
595-
" log_figure_from_history(env.get_info(), episode, tracking, \"test\", commit=True)\n",
597+
" log_figure_from_history(env.get_info(), episode, tracking, \"test\", env, commit=True)\n",
596598
" if ((episode+1) % print_freq) == 0:\n",
597599
" logging.info(f\"Episode {episode+1}: R={R}, J={J}\")\n",
598600
" elif agent.train_mode == \"pretrained\":\n",
@@ -651,7 +653,7 @@
651653
" sys.stdout.flush()\n",
652654
" wandb.log({f\"test/R\": R, f\"test/J\": J}, commit=False)\n",
653655
" log_info_history([ep[1]for ep in episode_dataset], episode, tracking, \"test\", commit=False)\n",
654-
" log_figure_from_history([ep[1]for ep in episode_dataset], episode, tracking, \"test\", commit=True)\n",
656+
" log_figure_from_history([ep[1]for ep in episode_dataset], episode, tracking, \"test\", env, commit=True)\n",
655657
" dataset.append(episode_dataset)\n",
656658
" \n",
657659
" elif agent.train_mode == \"env_interaction\":\n",

0 commit comments

Comments (0)