from __future__ import annotations

import os
import json

import mesh_model.random_quadmesh as QM
from environment.gymnasium_envs.quadmesh_env import QuadMeshEnv
from plots.mesh_plotter import dataset_plt
from exploit_SB3_policy import testPolicy
from stable_baselines3 import PPO, SAC
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figure

import gymnasium as gym

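# ---------------------------------------------------------------------------
# Overview: train a Stable-Baselines3 PPO agent on the quad-mesh Gymnasium
# environment, log per-episode metrics and before/after mesh figures to
# TensorBoard through the callback below, then save the trained policy.
# The PPO and environment settings are read from the JSON files loaded further
# down in this script.
# ---------------------------------------------------------------------------
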
class TensorboardCallback(BaseCallback):
    """
    Custom callback that logs additional per-episode values to TensorBoard.
    """
    def __init__(self, model, verbose=0):
        super().__init__(verbose)
        self.model = model
        self.episode_rewards = []
        self.mesh_reward = 0
        self.current_episode_reward = 0
        self.episode_count = 0
        self.current_episode_length = 0
        # Per-episode action counters, broken down by action type and validity
        self.actions_info = {
            "episode_valid_actions": 0,
            "episode_invalid_topo": 0,
            "episode_invalid_geo": 0,
            "nb_flip": 0,
            "nb_split": 0,
            "nb_collapse": 0,
            "nb_cleanup": 0,
            "nb_invalid_flip": 0,
            "nb_invalid_split": 0,
            "nb_invalid_collapse": 0,
            "nb_invalid_cleanup": 0,
        }
        self.final_distance = 0
        self.normalized_return = 0

    def _on_training_start(self) -> None:
        """
        Record the PPO parameters and the environment configuration at training start.
        """
        # ppo_config and env_config are the module-level dicts loaded from JSON below
        self.logger.record("parameters/ppo", f"<pre>{json.dumps(ppo_config, indent=4)}</pre>")
        self.logger.record("parameters/env", f"<pre>{json.dumps(env_config, indent=4)}</pre>")
        self.logger.dump(step=0)

    def _on_step(self) -> bool:
        """
        Record the learning metrics to monitor during training.
        """
        info = self.locals["infos"][0]
        self.current_episode_reward += self.locals["rewards"][0]
        self.current_episode_length += 1

        self.actions_info["episode_valid_actions"] += info.get("valid_action", 0.0)
        self.actions_info["episode_invalid_topo"] += info.get("invalid_topo", 0.0)
        self.actions_info["episode_invalid_geo"] += info.get("invalid_geo", 0.0)
        self.actions_info["nb_flip"] += info.get("flip", 0.0)
        self.actions_info["nb_split"] += info.get("split", 0.0)
        self.actions_info["nb_collapse"] += info.get("collapse", 0.0)
        self.actions_info["nb_cleanup"] += info.get("cleanup", 0.0)
        self.actions_info["nb_invalid_flip"] += info.get("invalid_flip", 0.0)
        self.actions_info["nb_invalid_split"] += info.get("invalid_split", 0.0)
        self.actions_info["nb_invalid_collapse"] += info.get("invalid_collapse", 0.0)
        self.actions_info["nb_invalid_cleanup"] += info.get("invalid_cleanup", 0.0)

        self.mesh_reward += info.get("mesh_reward", 0.0)

        # When the episode is over
        if self.locals["dones"][0]:
            self.episode_rewards.append(self.current_episode_reward)  # total reward obtained during the episode
            mesh_ideal_reward = info.get("mesh_ideal_rewards", 0.0)  # maximum achievable reward
            if mesh_ideal_reward > 0:
                self.normalized_return = self.mesh_reward / mesh_ideal_reward
            else:
                self.normalized_return = 0

            self.final_distance = info.get("distance", 0.0)
            self.logger.record("final_distance", self.final_distance)
            # Percentage of valid actions over the episode
            self.logger.record("valid_actions", self.actions_info["episode_valid_actions"] * 100 / self.current_episode_length if self.current_episode_length > 0 else 0)
            self.logger.record("n_invalid_topo", self.actions_info["episode_invalid_topo"])
            self.logger.record("n_invalid_geo", self.actions_info["episode_invalid_geo"])
            self.logger.record("nb_flip", self.actions_info["nb_flip"])
            self.logger.record("nb_split", self.actions_info["nb_split"])
            self.logger.record("nb_collapse", self.actions_info["nb_collapse"])
            self.logger.record("nb_cleanup", self.actions_info["nb_cleanup"])
            # Invalid-action rate per action type, in percent
            self.logger.record("invalid_flip", self.actions_info["nb_invalid_flip"] * 100 / self.actions_info["nb_flip"] if self.actions_info["nb_flip"] > 0 else 0)
            self.logger.record("invalid_split", self.actions_info["nb_invalid_split"] * 100 / self.actions_info["nb_split"] if self.actions_info["nb_split"] > 0 else 0)
            self.logger.record("invalid_collapse", self.actions_info["nb_invalid_collapse"] * 100 / self.actions_info["nb_collapse"] if self.actions_info["nb_collapse"] > 0 else 0)
            self.logger.record("invalid_cleanup", self.actions_info["nb_invalid_cleanup"] * 100 / self.actions_info["nb_cleanup"] if self.actions_info["nb_cleanup"] > 0 else 0)
            self.logger.record("episode_mesh_reward", self.mesh_reward)
            self.logger.record("episode_reward", self.current_episode_reward)
            self.logger.record("normalized_return", self.normalized_return)
            self.logger.record("episode_length", self.current_episode_length)

            is_success = info.get("is_success", 0.0)  # default value: 0.0
            self.logger.record("episode_success", is_success)

            self.logger.dump(step=self.episode_count)
            self.current_episode_reward = 0  # reset the global episode reward
            self.mesh_reward = 0  # reset the mesh episode reward
            self.current_episode_length = 0
            # Reset the per-episode action counters
            for key in self.actions_info.keys():
                self.actions_info[key] = 0
            self.episode_count += 1  # increment the episode counter

        return True

    def _on_training_end(self) -> None:
        """
        Record the policy evaluation results: before and after dataset images.
        """
        dataset = [QM.random_mesh() for _ in range(9)]  # dataset of 9 random meshes of size 30
        before = dataset_plt(dataset)  # plot the dataset as an image
        # Test the model policy on the dataset
        length, wins, rewards, normalized_return, final_meshes = testPolicy(self.model, 10, env_config, dataset)
        after = dataset_plt(final_meshes)
        self.logger.record("figures/before", Figure(before, close=True), exclude=("stdout", "log"))
        self.logger.record("figures/after", Figure(after, close=True), exclude=("stdout", "log"))
        self.logger.dump(step=0)


with open("../model_RL/parameters/ppo_config.json", "r") as f:
    ppo_config = json.load(f)
with open("../environment/parameters/environment_config.json", "r") as f:
    env_config = json.load(f)
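
# A sketch of the structure these two JSON files are expected to have, inferred
# only from the keys read in this script; the values shown are illustrative
# assumptions, not the project's actual settings.
#
# ppo_config.json:
#   {"policy": "MlpPolicy", "n_steps": 2048, "n_epochs": 10, "batch_size": 64,
#    "learning_rate": 0.0003, "gamma": 0.99, "verbose": 1,
#    "total_timesteps": 1000000, "tensorboard_log": "./ppo_quadmesh_tensorboard/"}
#
# environment_config.json:
#   {"env_name": "<registered quad-mesh env id>", "max_episode_steps": ...,
#    "n_darts_selected": ..., "deep": ..., "action_restriction": ...,
#    "with_degree_observation": ...}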

# Create log dir
log_dir = ppo_config["tensorboard_log"]
os.makedirs(log_dir, exist_ok=True)

# Create the environment; the env id comes from the configuration file and is
# presumably registered with Gymnasium when quadmesh_env is imported above
env = gym.make(
    env_config["env_name"],
    max_episode_steps=env_config["max_episode_steps"],
    n_darts_selected=env_config["n_darts_selected"],
    deep=env_config["deep"],
    action_restriction=env_config["action_restriction"],
    with_degree_obs=env_config["with_degree_observation"]
)

# Check that the environment follows the API expected by Stable-Baselines3
check_env(env, warn=True)

# Instantiate the PPO model with the hyperparameters from the configuration
model = PPO(
    policy=ppo_config["policy"],
    env=env,
    n_steps=ppo_config["n_steps"],
    n_epochs=ppo_config["n_epochs"],
    batch_size=ppo_config["batch_size"],
    learning_rate=ppo_config["learning_rate"],
    gamma=ppo_config["gamma"],
    verbose=ppo_config["verbose"],
    tensorboard_log=log_dir
)

print("-----------Starting learning-----------")
model.learn(total_timesteps=ppo_config["total_timesteps"], callback=TensorboardCallback(model))
print("-----------Learning ended------------")
model.save("policy_saved/quad/test3")
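
# A minimal post-training sketch using the standard Stable-Baselines3 and
# Gymnasium APIs: reload the saved policy and run a single greedy step on the
# environment as a quick sanity check.
loaded_model = PPO.load("policy_saved/quad/test3")
obs, _ = env.reset()
action, _states = loaded_model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
print(f"Reloaded policy sanity check: reward={reward}")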