-from numpy import ndarray
-
-import gymnasium as gym
 import json
 import torch
+import copy
+import numpy as np
+import gymnasium as gym
+import yaml
+
+from tqdm import tqdm
+from numpy import ndarray
 from torch.distributions import Categorical
-from model_RL.PPO_model_pers import Actor
 
+from mesh_model.mesh_analysis.quadmesh_analysis import QuadMeshOldAnalysis
+from mesh_model.mesh_analysis.trimesh_analysis import TriMeshOldAnalysis, TriMeshQualityAnalysis
+from mesh_model.mesh_struct.mesh_elements import Dart
 from mesh_model.mesh_struct.mesh import Mesh
 from mesh_model.reader import read_gmsh
+from model_RL.PPO_model_pers import Actor
+
 from view.mesh_plotter.create_plots import plot_test_results
 from view.mesh_plotter.mesh_plots import plot_dataset
 from environment.actions.smoothing import smoothing_mean
-import mesh_model.random_quadmesh as QM
+
 from environment.gymnasium_envs.quadmesh_env.envs.quadmesh import QuadMeshEnv
-import numpy as np
-import copy
-from tqdm import tqdm
+from environment.gymnasium_envs.trimesh_full_env.envs.trimesh import TriMeshEnvFull
 
+import mesh_model.random_quadmesh as QM
 
 def testPolicy(
         actor,
         n_eval_episodes: int,
-        env_config,
+        config,
         dataset: list[Mesh]
 ) -> tuple[ndarray, ndarray, ndarray, ndarray, list[Mesh]]:
     """
     Tests the policy on each mesh of a dataset with n_eval_episodes.
-    :param policy: the policy to test
+    :param actor: the policy to test
     :param n_eval_episodes: number of evaluation episodes on each mesh
+    :param config: configuration dictionary (environment and PPO settings)
     :param dataset: list of mesh objects
-    :param max_steps: max steps to evaluate
     :return: average length of evaluation episodes, number of wins, average reward per mesh, dataset with the modified meshes
     """
     print('Testing policy')
@@ -41,14 +48,15 @@ def testPolicy(
     for i, mesh in tqdm(enumerate(dataset, 1)):
         best_mesh = mesh
         env = gym.make(
-            env_config["env_name"],
-            max_episode_steps=30,
-            mesh=mesh,
-            n_darts_selected=env_config["n_darts_selected"],
-            deep=env_config["deep"],
-            action_restriction=env_config["action_restriction"],
-            with_degree_obs=env_config["with_degree_observation"],
-            render_mode="human"
+            config["env"]["env_id"],
+            max_episode_steps=config["env"]["max_episode_steps"],
+            mesh=mesh,
+            # mesh_size=30,
+            n_darts_selected=config["env"]["n_darts_selected"],
+            deep=config["env"]["deep"],
+            action_restriction=config["env"]["action_restriction"],
+            with_quality_obs=config["env"]["with_quality_observation"],
+            render_mode=config["env"]["render_mode"],
         )
         for _ in range(n_eval_episodes):
             terminated = False
@@ -62,8 +70,8 @@ def testPolicy(
                 dist = Categorical(pmf)
                 action = dist.sample()
                 action = action.tolist()
-                action_dart = int(action / 4)
-                action_type = action % 4
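+                # Decode the flat action index: dart index = action // n_actions, action type = action % n_actions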
+                action_dart = int(action / config["ppo"]["n_actions"])
+                action_type = action % config["ppo"]["n_actions"]
                 gymnasium_action = [action_type, action_dart]
                 if action is None:
                     env.terminal = True
@@ -73,7 +81,7 @@ def testPolicy(
                 ep_length += 1
                 if terminated:
                     nb_wins[i - 1] += 1
-                if isBetterMesh(best_mesh, info['mesh']):
+                if isBetterMesh(best_mesh, info['mesh'], config["env"]["analysis_type"]):
                     best_mesh = copy.deepcopy(info['mesh'])
             avg_length[i - 1] += ep_length
             avg_mesh_rewards[i - 1] += ep_mesh_rewards
@@ -89,37 +97,57 @@ def isBetterPolicy(actual_best_policy, policy_to_test):
     if actual_best_policy is None:
         return True
 
-def isBetterMesh(best_mesh, actual_mesh):
-    if best_mesh is None or global_score(best_mesh)[1] > global_score(actual_mesh)[1]:
+def isBetterMesh(best_mesh, actual_mesh, analysis_type):
+    tri = False
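+    # Element-type check: if following beta_1 three times from a dart returns to the same dart, the faces are triangles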
+    for d_info in actual_mesh.dart_info:
+        if d_info[0] >= 0:
+            d = Dart(actual_mesh, d_info[0])
+            if d == ((d.get_beta(1)).get_beta(1)).get_beta(1):
+                tri = True
+            else:
+                tri = False
+            break
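+    # Select the mesh analysis class matching the element type and the configured analysis_type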
+    if tri:
+        if analysis_type == "old":
+            ma_best_mesh = TriMeshOldAnalysis(best_mesh)
+            ma_actual_mesh = TriMeshOldAnalysis(actual_mesh)
+        else:
+            ma_best_mesh = TriMeshQualityAnalysis(best_mesh)
+            ma_actual_mesh = TriMeshQualityAnalysis(actual_mesh)
+    else:
+        ma_best_mesh = QuadMeshOldAnalysis(best_mesh)
+        ma_actual_mesh = QuadMeshOldAnalysis(actual_mesh)
+    if best_mesh is None or ma_best_mesh.global_score()[1] > ma_actual_mesh.global_score()[1]:
         return True
     else:
         return False
 
 
 if __name__ == '__main__':
 
-
     # Create a dataset of 9 meshes
-    mesh = read_gmsh("../mesh_files/medium_quad.msh")
+    mesh = read_gmsh("../mesh_files/t1_tri.msh")
     dataset = [mesh for _ in range(9)]
-    with open("../environment/old_files/environment_config.json", "r") as f:
-        env_config = json.load(f)
+    with open("../training/config/trimesh_config_PPO_perso.yaml", "r") as f:
+        config = yaml.safe_load(f)
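+    # Keys read by this script: under config["env"]: env_id, max_episode_steps, n_darts_selected, deep,
+    # action_restriction, with_quality_observation, render_mode, analysis_type, obs_size; under config["ppo"]: n_actions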
     plot_dataset(dataset)
 
     env = gym.make(
-        env_config["env_name"],
+        config["env"]["env_id"],
+        max_episode_steps=config["env"]["max_episode_steps"],
         mesh=mesh,
-        max_episode_steps=env_config["max_episode_steps"],
-        n_darts_selected=env_config["n_darts_selected"],
-        deep=env_config["deep"],
-        action_restriction=env_config["action_restriction"],
-        with_degree_obs=env_config["with_degree_observation"]
+        # mesh_size=30,
+        n_darts_selected=config["env"]["n_darts_selected"],
+        deep=config["env"]["deep"],
+        action_restriction=config["env"]["action_restriction"],
+        with_quality_obs=config["env"]["with_quality_observation"],
+        render_mode=config["env"]["render_mode"],
     )
 
     # Load the model
-    actor = Actor(env, 10 * 8, 4 * 10, lr=0.0001)
-    actor.load_state_dict(torch.load('policy_saved/quad-perso/medium_quad_perso-2.pth'))
-    avg_steps, avg_wins, avg_rewards, normalized_return, final_meshes = testPolicy(actor, 15, env_config, dataset)
+    actor = Actor(env, config["env"]["obs_size"], config["ppo"]["n_actions"], n_darts_observed=config["env"]["n_darts_selected"], lr=0.0001)
+    actor.load_state_dict(torch.load('policy_saved/tri-perso/TEST-Exploit.pth'))
+    avg_steps, avg_wins, avg_rewards, normalized_return, final_meshes = testPolicy(actor, 15, config, dataset)
 
     plot_test_results(avg_rewards, avg_wins, avg_steps, normalized_return)
     plot_dataset(final_meshes)