
Commit e999f81

Updated SB3 quad training script

Weights & Biases integration. New progress bar for training. New callbacks: WandB, HParam, Evaluation.
* WandB: syncs the TensorBoard scalar metrics to wandb and saves the model and the code.
* HParam: provides a structured visualization of the hyperparameters.
* Evaluation: the model is evaluated every n steps, with early stopping: if no improvement is observed over several evaluations, training is stopped automatically.
1 parent 1865d67 commit e999f81
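The callback composition described in the message above (W&B syncing of the TensorBoard scalars, periodic evaluation, early stopping, progress bar) follows the standard SB3 pattern. A minimal sketch on a stock Gymnasium environment, purely for illustration; the project name, paths and thresholds below are assumptions, not values from this commit:

import gymnasium as gym
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement

# Sync SB3's TensorBoard scalars to W&B and keep a copy of the launching code
run = wandb.init(project="sb3-demo", sync_tensorboard=True, save_code=True)

train_env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

# Early stopping: abort once 5 consecutive evaluations bring no new best mean reward
stop_cb = StopTrainingOnNoModelImprovement(max_no_improvement_evals=5, min_evals=5, verbose=1)
eval_cb = EvalCallback(eval_env, eval_freq=500, callback_after_eval=stop_cb, verbose=1)

model = PPO("MlpPolicy", train_env, tensorboard_log="runs", verbose=0)
model.learn(
    total_timesteps=50_000,
    callback=[WandbCallback(model_save_path=f"models/{run.id}"), eval_cb],
    progress_bar=True,  # requires tqdm and rich
)
run.finish()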

File tree: 4 files changed, +133 −23 lines

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+{
+    "env_name": "Quadmesh-v0",
+    "mesh_size": 16,
+    "max_episode_steps": 20,
+    "n_darts_selected": 10,
+    "deep": 8,
+    "action_restriction": false,
+    "with_degree_observation": false
+}
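A config file like the one above is consumed by train_quadmesh_SB3.py through json.load and gym.make (see the last file in this commit). A short sketch of that usage; the file path below is a placeholder since the new file's name is not shown in this view, and the keyword mapping mirrors the gym.make call in the training script:

import json
import gymnasium as gym
from environment.gymnasium_envs import quadmesh_env  # registers "Quadmesh-v0"

# Placeholder path: the actual name of the new config file is not visible here
with open("environment/quadmesh_config.json", "r") as f:
    cfg = json.load(f)

env = gym.make(
    cfg["env_name"],
    max_episode_steps=cfg["max_episode_steps"],
    n_darts_selected=cfg["n_darts_selected"],
    deep=cfg["deep"],
    action_restriction=cfg["action_restriction"],
    with_degree_obs=cfg["with_degree_observation"],
)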

environment/gymnasium_envs/quadmesh_env/envs/mesh_conv.py

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,8 @@ def get_template(mesh: Mesh, deep: int, nodes_scores):
             template[n_darts - 1, len(E)-1] = nodes_scores[N2.id]
         else:
             E.extend([None,None])
+            #template[n_darts - 1, len(E) - 1] = -500 # dummy vertices are assigned to -500
+            #template[n_darts - 1, len(E) - 2] = -500 # dummy vertices are assigned to -500

     template = template[:n_darts, :]

mesh_model/random_quadmesh.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ def random_mesh() -> Mesh:
     #filename = os.path.join('../mesh_files/', 't1_quad.msh')
     #mesh = read_gmsh("/home/ropercha/PycharmProjects/tune/mesh_files/t1_quad.msh")
     mesh = read_gmsh(filename)
-    mesh_shuffle(mesh, 10)
+    mesh_shuffle(mesh, 5)
     return mesh

 def mesh_shuffle(mesh: Mesh, num_nodes) -> Mesh:

training/train_quadmesh_SB3.py

Lines changed: 121 additions & 22 deletions
@@ -2,22 +2,65 @@

 import os
 import json
+import matplotlib.pyplot as plt
+from sphinx.util import os_path

 import mesh_model.random_quadmesh as QM
 from mesh_model.reader import read_gmsh
 from view.mesh_plotter.mesh_plots import dataset_plt
 from training.exploit_SB3_policy import testPolicy
 from stable_baselines3 import PPO
 from stable_baselines3.common.env_checker import check_env
-from stable_baselines3.common.callbacks import BaseCallback
-from stable_baselines3.common.logger import Figure
+from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, StopTrainingOnNoModelImprovement, ProgressBarCallback
+from stable_baselines3.common.logger import Figure, HParam
+import wandb
+from wandb.integration.sb3 import WandbCallback

 from environment.gymnasium_envs import quadmesh_env

 import gymnasium as gym
 import random
 import numpy as np
 import torch
+import os
+import tqdm
+import rich
+
+class HParamCallback(BaseCallback):
+    """
+    Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
+    """
+
+    def _on_training_start(self) -> None:
+        hparam_dict = {
+            "algorithm": self.model.__class__.__name__,
+            "experiment": experiment_name,
+            "learning rate": self.model.learning_rate,
+            "gamma": self.model.gamma,
+            "batch_size": ppo_config["batch_size"],
+            "epochs": ppo_config["n_epochs"],
+            "training_meshes": training_mesh_file_path,
+            "evaluation_meshes": evaluation_mesh_file_path,
+            "max_steps": env_config["max_episode_steps"],
+            "max_timesteps": ppo_config["total_timesteps"],
+        }
+        # define the metrics that will appear in the `HPARAMS` TensorBoard tab by referencing their tag
+        # TensorBoard will find & display metrics from the `SCALARS` tab
+        metric_dict = {
+            "normalized_return": 0,
+            "rollout/ep_len_mean": 0.0,
+            "rollout/ep_rew_mean": 0.0
+        }
+        self.logger.record(
+            "hparams",
+            HParam(hparam_dict, metric_dict),
+            exclude=("stdout", "log", "json", "csv"),
+        )
+
+    def _on_step(self) -> bool:
+        return True

 class TensorboardCallback(BaseCallback):
     """
@@ -54,7 +97,7 @@ def _on_training_start(self) -> None:
         """
         self.logger.record("parameters/ppo", f"<pre>{json.dumps(ppo_config, indent=4)}</pre>")
         self.logger.record("parameters/env", f"<pre>{json.dumps(env_config, indent=4)}</pre>")
-        self.logger.dump(step=0)
+

     def _on_step(self) -> bool:
         """
@@ -90,17 +133,17 @@ def _on_step(self) -> bool:
         self.final_distance = self.locals["infos"][0].get("distance", 0.0)
         self.logger.record("final_distance", self.final_distance)
         self.logger.record("valid_actions", self.actions_info["episode_valid_actions"]*100/self.current_episode_length if self.current_episode_length > 0 else 0)
-        self.logger.record("n_invalid_topo", self.actions_info["episode_invalid_topo"])
-        self.logger.record("n_invalid_geo", self.actions_info["episode_invalid_geo"])
-        self.logger.record("nb_flip_cw", self.actions_info["nb_flip_cw"])
-        self.logger.record("nb_flip_cntcw", self.actions_info["nb_flip_cntcw"])
-        self.logger.record("nb_split", self.actions_info["nb_split"])
-        self.logger.record("nb_collapse", self.actions_info["nb_collapse"])
-        self.logger.record("nb_cleanup", self.actions_info["nb_cleanup"])
-        self.logger.record("invalid_flip", self.actions_info["nb_invalid_flip"]*100/(self.actions_info["nb_flip_cw"]+self.actions_info["nb_flip_cntcw"]) if (self.actions_info["nb_flip_cw"]+self.actions_info["nb_flip_cntcw"]) > 0 else 0)
-        self.logger.record("invalid_split", self.actions_info["nb_invalid_split"]*100/self.actions_info["nb_split"] if self.actions_info["nb_split"] > 0 else 0)
-        self.logger.record("invalid_collapse", self.actions_info["nb_invalid_collapse"]*100/self.actions_info["nb_collapse"] if self.actions_info["nb_collapse"] > 0 else 0)
-        self.logger.record("invalid_cleanup", self.actions_info["nb_invalid_cleanup"]*100/self.actions_info["nb_cleanup"] if self.actions_info["nb_cleanup"] > 0 else 0)
+        self.logger.record("actions/n_invalid_topo", self.actions_info["episode_invalid_topo"])
+        self.logger.record("actions/n_invalid_geo", self.actions_info["episode_invalid_geo"])
+        self.logger.record("actions/nb_flip_cw", self.actions_info["nb_flip_cw"])
+        self.logger.record("actions/nb_flip_cntcw", self.actions_info["nb_flip_cntcw"])
+        self.logger.record("actions/nb_split", self.actions_info["nb_split"])
+        self.logger.record("actions/nb_collapse", self.actions_info["nb_collapse"])
+        self.logger.record("actions/nb_cleanup", self.actions_info["nb_cleanup"])
+        self.logger.record("actions/invalid_flip", self.actions_info["nb_invalid_flip"]*100/(self.actions_info["nb_flip_cw"]+self.actions_info["nb_flip_cntcw"]) if (self.actions_info["nb_flip_cw"]+self.actions_info["nb_flip_cntcw"]) > 0 else 0)
+        self.logger.record("actions/invalid_split", self.actions_info["nb_invalid_split"]*100/self.actions_info["nb_split"] if self.actions_info["nb_split"] > 0 else 0)
+        self.logger.record("actions/invalid_collapse", self.actions_info["nb_invalid_collapse"]*100/self.actions_info["nb_collapse"] if self.actions_info["nb_collapse"] > 0 else 0)
+        self.logger.record("actions/invalid_cleanup", self.actions_info["nb_invalid_cleanup"]*100/self.actions_info["nb_cleanup"] if self.actions_info["nb_cleanup"] > 0 else 0)
         self.logger.record("episode_mesh_reward", self.mesh_reward)
         self.logger.record("episode_reward", self.current_episode_reward)
         self.logger.record("normalized_return", self.normalized_return)
@@ -123,42 +166,97 @@ def _on_step(self) -> bool:
     def _on_training_end(self) -> None:
         """
         Records policy evaluation results: before and after dataset images.
+        Saves the observation count registry to a JSON file and records the analysis.
         """
-        filename = "counts_PPO47.json"
+        filename = "counts_" + experiment_name + ".json"
        counts_registry = self.locals["infos"][0].get("observation_count", 0.0)
         counts = counts_registry.counts

         # Convert the tuple keys to strings
-        counts_str_keys = {v: str(k) for k, v in counts.items()}
+        counts_str_keys = [(v, str(k)) for k, v in counts.items()]
+        counts_values = list(counts.values())

         # Write to a JSON file
         with open(filename, "w") as file:
             json.dump(counts_str_keys, file, indent=4)

         print(f"Counts saved at {filename}")

+        self.logger.record("observation/n_observation", len(counts_values))
+        self.logger.record("observation/mean", np.mean(counts_values))
+        self.logger.record("observation/median", np.median(counts_values))
+        self.logger.record("observation/min", np.min(counts_values))
+        self.logger.record("observation/max", np.max(counts_values))
+
+        counts_values.sort()
+        figure, ax = plt.subplots()
+        ax.hist(counts_values, bins='auto')
+        ax.set_title("Observation counts")
+        # Close the figure after logging it
+        self.logger.record("observation/counts", Figure(figure, close=True), exclude=("stdout", "log", "json", "csv"))
+        plt.close()
+
         #mesh = read_gmsh("mesh_files/medium_quad.msh")
         dataset = [QM.random_mesh() for _ in range(9)]  # dataset of 9 meshes of size 30
         before = dataset_plt(dataset)  # plot the dataset as an image
         length, wins, rewards, normalized_return, final_meshes = testPolicy(self.model, 10, env_config, dataset)  # test the model policy on the dataset
         after = dataset_plt(final_meshes)
         self.logger.record("figures/before", Figure(before, close=True), exclude=("stdout", "log"))
         self.logger.record("figures/after", Figure(after, close=True), exclude=("stdout", "log"))
-        self.logger.dump(step=0)
+        self.logger.dump(step=self.num_timesteps)

 if __name__ == '__main__':

+    experiment_name = "wandb_test"
+    ppo_config_path = "model_RL/parameters/ppo_config.json"
+    env_config_path = "environment/environment_config.json"
+    eval_env_config_path = "environment/eval_environment_config.json"
+    policy_saving_path = os.path.join("training/policy_saved/quad/", experiment_name)
+    wandb_model_save_path = f"training/wandb_models/{experiment_name}"
+
+    # Mesh datasets
+    evaluation_mesh_file_path = "mesh_files/simple_quad.msh"
+    training_mesh_file_path = "mesh_files/simple_quad.msh"
+
     # SEEDING
     seed = 1
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.backends.cudnn.deterministic = True

-    with open("model_RL/parameters/ppo_config.json", "r") as f:
+    # PARAMETERS CONFIGURATION
+
+    with open(ppo_config_path, "r") as f:
         ppo_config = json.load(f)
-    with open("environment/environment_config.json", "r") as f:
+    with open(env_config_path, "r") as f:
         env_config = json.load(f)
+    with open(eval_env_config_path, "r") as f:
+        eval_env_config = json.load(f)
+
+    # WANDB
+    run = wandb.init(
+        project="sb3",
+        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
+        save_code=True,  # optional
+    )
+
+    # EVALUATION CALLBACKS
+
+    # Separate evaluation env
+    eval_env = gym.make(
+        eval_env_config["env_name"],
+        mesh=read_gmsh(evaluation_mesh_file_path),
+        max_episode_steps=eval_env_config["max_episode_steps"],
+        n_darts_selected=eval_env_config["n_darts_selected"],
+        deep=eval_env_config["deep"],
+        action_restriction=eval_env_config["action_restriction"],
+        with_degree_obs=eval_env_config["with_degree_observation"]
+    )
+    # Stop training if no improvement is observed over 5 consecutive evaluations
+    stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=5, min_evals=5, verbose=1)
+    eval_callback = EvalCallback(eval_env, eval_freq=500, callback_after_eval=stop_train_callback, verbose=1)

     # Create log dir
     log_dir = ppo_config["tensorboard_log"]
@@ -167,7 +265,7 @@ def _on_training_end(self) -> None:
     # Create the environment
     env = gym.make(
         env_config["env_name"],
-        #mesh = read_gmsh("../mesh_files/medium_quad.msh"),
+        mesh=read_gmsh(training_mesh_file_path),
         max_episode_steps=env_config["max_episode_steps"],
         n_darts_selected=env_config["n_darts_selected"],
         deep=env_config["deep"],
@@ -190,6 +288,7 @@ def _on_training_end(self) -> None:
     )

     print("-----------Starting learning-----------")
-    model.learn(total_timesteps=ppo_config["total_timesteps"], callback=TensorboardCallback(model))
+    model.learn(total_timesteps=ppo_config["total_timesteps"], tb_log_name=experiment_name, callback=[HParamCallback(), WandbCallback(model_save_path=wandb_model_save_path), TensorboardCallback(model), eval_callback], progress_bar=True)
     print("-----------Learning ended------------")
-    model.save("training/policy_saved/quad/4-actions-quad-rand_simple-PPO47")
+    model.save(policy_saving_path)
+    run.finish()
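Once the run finishes, the saved policy can be reloaded and re-evaluated outside of training with the same helper the TensorboardCallback uses. A short sketch using the paths defined in this commit; the dataset size and print format are illustrative:

import json
from stable_baselines3 import PPO

import mesh_model.random_quadmesh as QM
from training.exploit_SB3_policy import testPolicy

with open("environment/environment_config.json", "r") as f:
    env_config = json.load(f)

# SB3 appends .zip to the path passed to model.save()
model = PPO.load("training/policy_saved/quad/wandb_test")

dataset = [QM.random_mesh() for _ in range(9)]
length, wins, rewards, normalized_return, final_meshes = testPolicy(model, 10, env_config, dataset)
print(f"wins: {wins}, normalized return: {normalized_return}")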
