
import os
import json
+import matplotlib.pyplot as plt
+from sphinx.util import os_path

import mesh_model.random_quadmesh as QM
from mesh_model.reader import read_gmsh
from view.mesh_plotter.mesh_plots import dataset_plt
from training.exploit_SB3_policy import testPolicy
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
-from stable_baselines3.common.callbacks import BaseCallback
-from stable_baselines3.common.logger import Figure
+from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, StopTrainingOnNoModelImprovement, ProgressBarCallback
+from stable_baselines3.common.logger import Figure, HParam
+import wandb
+from wandb.integration.sb3 import WandbCallback

from environment.gymnasium_envs import quadmesh_env

import gymnasium as gym
import random
import numpy as np
import torch
+import os
+import tqdm
+import rich
+
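+# HParamCallback records the run's hyperparameters once at training start so that
+# TensorBoard's HPARAMS tab can cross-reference them with the scalar metrics listed in metric_dict.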
+class HParamCallback(BaseCallback):
+    """
+    Saves the hyperparameters and metrics at the start of training and logs them to TensorBoard.
+    """
+
+    def _on_training_start(self) -> None:
+        hparam_dict = {
+            "algorithm": self.model.__class__.__name__,
+            "experiment": experiment_name,
+            "learning rate": self.model.learning_rate,
+            "gamma": self.model.gamma,
+            "batch_size": ppo_config["batch_size"],
+            "epochs": ppo_config["n_epochs"],
+            "training_meshes": training_mesh_file_path,
+            "evaluation_meshes": evaluation_mesh_file_path,
+            "max_steps": env_config["max_episode_steps"],
+            "max_timesteps": ppo_config["total_timesteps"],
+        }
+        # Define the metrics that will appear in the `HPARAMS` TensorBoard tab by referencing their tag;
+        # TensorBoard will find & display the matching metrics from the `SCALARS` tab.
+        metric_dict = {
+            "normalized_return": 0,
+            "rollout/ep_len_mean": 0.0,
+            "rollout/ep_rew_mean": 0.0,
+        }
+        self.logger.record(
+            "hparams",
+            HParam(hparam_dict, metric_dict),
+            exclude=("stdout", "log", "json", "csv"),
+        )
+
+    def _on_step(self) -> bool:
+        return True

class TensorboardCallback(BaseCallback):
    """
@@ -54,7 +97,7 @@ def _on_training_start(self) -> None:
        """
        self.logger.record("parameters/ppo", f"<pre>{json.dumps(ppo_config, indent=4)}</pre>")
        self.logger.record("parameters/env", f"<pre>{json.dumps(env_config, indent=4)}</pre>")
-        self.logger.dump(step=0)
+

    def _on_step(self) -> bool:
        """
@@ -90,17 +133,17 @@ def _on_step(self) -> bool:
            self.final_distance = self.locals["infos"][0].get("distance", 0.0)
            self.logger.record("final_distance", self.final_distance)
            self.logger.record("valid_actions", self.actions_info["episode_valid_actions"] * 100 / self.current_episode_length if self.current_episode_length > 0 else 0)
-            self.logger.record("n_invalid_topo", self.actions_info["episode_invalid_topo"])
-            self.logger.record("n_invalid_geo", self.actions_info["episode_invalid_geo"])
-            self.logger.record("nb_flip_cw", self.actions_info["nb_flip_cw"])
-            self.logger.record("nb_flip_cntcw", self.actions_info["nb_flip_cntcw"])
-            self.logger.record("nb_split", self.actions_info["nb_split"])
-            self.logger.record("nb_collapse", self.actions_info["nb_collapse"])
-            self.logger.record("nb_cleanup", self.actions_info["nb_cleanup"])
-            self.logger.record("invalid_flip", self.actions_info["nb_invalid_flip"] * 100 / (self.actions_info["nb_flip_cw"] + self.actions_info["nb_flip_cntcw"]) if (self.actions_info["nb_flip_cw"] + self.actions_info["nb_flip_cntcw"]) > 0 else 0)
-            self.logger.record("invalid_split", self.actions_info["nb_invalid_split"] * 100 / self.actions_info["nb_split"] if self.actions_info["nb_split"] > 0 else 0)
-            self.logger.record("invalid_collapse", self.actions_info["nb_invalid_collapse"] * 100 / self.actions_info["nb_collapse"] if self.actions_info["nb_collapse"] > 0 else 0)
-            self.logger.record("invalid_cleanup", self.actions_info["nb_invalid_cleanup"] * 100 / self.actions_info["nb_cleanup"] if self.actions_info["nb_cleanup"] > 0 else 0)
+            self.logger.record("actions/n_invalid_topo", self.actions_info["episode_invalid_topo"])
+            self.logger.record("actions/n_invalid_geo", self.actions_info["episode_invalid_geo"])
+            self.logger.record("actions/nb_flip_cw", self.actions_info["nb_flip_cw"])
+            self.logger.record("actions/nb_flip_cntcw", self.actions_info["nb_flip_cntcw"])
+            self.logger.record("actions/nb_split", self.actions_info["nb_split"])
+            self.logger.record("actions/nb_collapse", self.actions_info["nb_collapse"])
+            self.logger.record("actions/nb_cleanup", self.actions_info["nb_cleanup"])
+            self.logger.record("actions/invalid_flip", self.actions_info["nb_invalid_flip"] * 100 / (self.actions_info["nb_flip_cw"] + self.actions_info["nb_flip_cntcw"]) if (self.actions_info["nb_flip_cw"] + self.actions_info["nb_flip_cntcw"]) > 0 else 0)
+            self.logger.record("actions/invalid_split", self.actions_info["nb_invalid_split"] * 100 / self.actions_info["nb_split"] if self.actions_info["nb_split"] > 0 else 0)
+            self.logger.record("actions/invalid_collapse", self.actions_info["nb_invalid_collapse"] * 100 / self.actions_info["nb_collapse"] if self.actions_info["nb_collapse"] > 0 else 0)
+            self.logger.record("actions/invalid_cleanup", self.actions_info["nb_invalid_cleanup"] * 100 / self.actions_info["nb_cleanup"] if self.actions_info["nb_cleanup"] > 0 else 0)
            self.logger.record("episode_mesh_reward", self.mesh_reward)
            self.logger.record("episode_reward", self.current_episode_reward)
            self.logger.record("normalized_return", self.normalized_return)
@@ -123,42 +166,97 @@ def _on_step(self) -> bool:
    def _on_training_end(self) -> None:
        """
        Records policy evaluation results: before and after dataset images.
+        Saves the observation registry counts to a JSON file and records a summary analysis.
        """
-        filename = "counts_PPO47.json"
+        filename = "counts_" + experiment_name + ".json"
        counts_registry = self.locals["infos"][0].get("observation_count", 0.0)
        counts = counts_registry.counts

        # Convert the tuple keys to strings
-        counts_str_keys = {v: str(k) for k, v in counts.items()}
+        counts_str_keys = [(v, str(k)) for k, v in counts.items()]
+        counts_values = list(counts.values())

        # Write the counts to a JSON file
        with open(filename, "w") as file:
            json.dump(counts_str_keys, file, indent=4)

        print(f"Counts saved at {filename}")

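+        # Summary statistics of how often each observation was encountered during training,
+        # plus a histogram of the observation visit counts.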
+        self.logger.record("observation/n_observation", len(counts_values))
+        self.logger.record("observation/mean", np.mean(counts_values))
+        self.logger.record("observation/median", np.median(counts_values))
+        self.logger.record("observation/min", np.min(counts_values))
+        self.logger.record("observation/max", np.max(counts_values))
+
+        counts_values.sort()
+        figure, ax = plt.subplots()
+        ax.hist(counts_values, bins='auto')
+        ax.set_title("Observation counts")
+        # Close the figure after logging it
+        self.logger.record("observation/counts", Figure(figure, close=True), exclude=("stdout", "log", "json", "csv"))
+        plt.close()
+
+
        # mesh = read_gmsh("mesh_files/medium_quad.msh")
        dataset = [QM.random_mesh() for _ in range(9)]  # dataset of 9 meshes of size 30
        before = dataset_plt(dataset)  # plot the dataset as an image
        length, wins, rewards, normalized_return, final_meshes = testPolicy(self.model, 10, env_config, dataset)  # test the model policy on the dataset
        after = dataset_plt(final_meshes)
        self.logger.record("figures/before", Figure(before, close=True), exclude=("stdout", "log"))
        self.logger.record("figures/after", Figure(after, close=True), exclude=("stdout", "log"))
-        self.logger.dump(step=0)
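+        # Dump at the current timestep so these end-of-training entries align with the training timeline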
+        self.logger.dump(step=self.num_timesteps)

if __name__ == '__main__':

+    experiment_name = "wandb_test"
+    ppo_config_path = "model_RL/parameters/ppo_config.json"
+    env_config_path = "environment/environment_config.json"
+    eval_env_config_path = "environment/eval_environment_config.json"
+    policy_saving_path = os.path.join("training/policy_saved/quad/", experiment_name)
+    wandb_model_save_path = f"training/wandb_models/{experiment_name}"
+
+    # Mesh datasets
+    evaluation_mesh_file_path = "mesh_files/simple_quad.msh"
+    training_mesh_file_path = "mesh_files/simple_quad.msh"
+
+
    # SEEDING
    seed = 1
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

-    with open("model_RL/parameters/ppo_config.json", "r") as f:
+    # PARAMETERS CONFIGURATION
+
+    with open(ppo_config_path, "r") as f:
        ppo_config = json.load(f)
-    with open("environment/environment_config.json", "r") as f:
+    with open(env_config_path, "r") as f:
        env_config = json.load(f)
+    with open(eval_env_config_path, "r") as f:
+        eval_env_config = json.load(f)
+
+    # WANDB
+    run = wandb.init(
+        project="sb3",
+        sync_tensorboard=True,  # auto-upload SB3's TensorBoard metrics
+        save_code=True,  # optional
+    )
+    # EVALUATION CALLBACKS
+
+    # Separate evaluation env
+    eval_env = gym.make(
+        eval_env_config["env_name"],
+        mesh=read_gmsh(evaluation_mesh_file_path),
+        max_episode_steps=eval_env_config["max_episode_steps"],
+        n_darts_selected=eval_env_config["n_darts_selected"],
+        deep=eval_env_config["deep"],
+        action_restriction=eval_env_config["action_restriction"],
+        with_degree_obs=eval_env_config["with_degree_observation"]
+    )
+    # Stop training if there is no improvement after more than 5 consecutive evaluations
+    stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=5, min_evals=5, verbose=1)
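+    # The policy is evaluated on eval_env every eval_freq environment steps; stop_train_callback
+    # runs after each evaluation and halts training once no improvement is seen.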
+    eval_callback = EvalCallback(eval_env, eval_freq=500, callback_after_eval=stop_train_callback, verbose=1)

    # Create log dir
    log_dir = ppo_config["tensorboard_log"]
@@ -167,7 +265,7 @@ def _on_training_end(self) -> None:
    # Create the environment
    env = gym.make(
        env_config["env_name"],
-        # mesh=read_gmsh("../mesh_files/medium_quad.msh"),
+        mesh=read_gmsh(training_mesh_file_path),
        max_episode_steps=env_config["max_episode_steps"],
        n_darts_selected=env_config["n_darts_selected"],
        deep=env_config["deep"],
@@ -190,6 +288,7 @@ def _on_training_end(self) -> None:
    )

    print("-----------Starting learning-----------")
-    model.learn(total_timesteps=ppo_config["total_timesteps"], callback=TensorboardCallback(model))
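+    # Callbacks run side by side: HParamCallback logs the hyperparameters, WandbCallback uploads the run
+    # and saves the model under wandb_model_save_path, TensorboardCallback logs per-episode metrics,
+    # and eval_callback periodically evaluates the policy and can stop training early.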
+    model.learn(total_timesteps=ppo_config["total_timesteps"], tb_log_name=experiment_name, callback=[HParamCallback(), WandbCallback(model_save_path=wandb_model_save_path), TensorboardCallback(model), eval_callback], progress_bar=True)
    print("-----------Learning ended------------")
-    model.save("training/policy_saved/quad/4-actions-quad-rand_simple-PPO47")
+    model.save(policy_saving_path)
+    run.finish()