
Commit 79c3522

Fix issue with exploit file
1 parent 49663c6 commit 79c3522

File tree

2 files changed: +65 −38 lines changed

training/exploit_PPO_perso.py

Lines changed: 65 additions & 37 deletions
@@ -1,35 +1,42 @@
-from numpy import ndarray
-
-import gymnasium as gym
 import json
 import torch
+import copy
+import numpy as np
+import gymnasium as gym
+import yaml
+
+from tqdm import tqdm
+from numpy import ndarray
 from torch.distributions import Categorical
-from model_RL.PPO_model_pers import Actor
 
+from mesh_model.mesh_analysis.quadmesh_analysis import QuadMeshOldAnalysis
+from mesh_model.mesh_analysis.trimesh_analysis import TriMeshOldAnalysis, TriMeshQualityAnalysis
+from mesh_model.mesh_struct.mesh_elements import Dart
 from mesh_model.mesh_struct.mesh import Mesh
 from mesh_model.reader import read_gmsh
+from model_RL.PPO_model_pers import Actor
+
 from view.mesh_plotter.create_plots import plot_test_results
 from view.mesh_plotter.mesh_plots import plot_dataset
 from environment.actions.smoothing import smoothing_mean
-import mesh_model.random_quadmesh as QM
+
 from environment.gymnasium_envs.quadmesh_env.envs.quadmesh import QuadMeshEnv
-import numpy as np
-import copy
-from tqdm import tqdm
+from environment.gymnasium_envs.trimesh_full_env.envs.trimesh import TriMeshEnvFull
 
+import mesh_model.random_quadmesh as QM
 
 def testPolicy(
         actor,
         n_eval_episodes: int,
-        env_config,
+        config,
         dataset: list[Mesh]
 ) -> tuple[ndarray, ndarray, ndarray, ndarray, list[Mesh]]:
     """
     Tests policy on each mesh of a dataset with n_eval_episodes.
-    :param policy: the policy to test
+    :param actor: the policy to test
     :param n_eval_episodes: number of evaluation episodes on each mesh
+    :param config: configuration
     :param dataset: list of mesh objects
-    :param max_steps: max steps to evaluate
     :return: average length of evaluation episodes, number of wins, average reward per mesh, dataset with the modified meshes
     """
     print('Testing policy')
@@ -41,14 +48,15 @@ def testPolicy(
     for i, mesh in tqdm(enumerate(dataset, 1)):
         best_mesh = mesh
         env = gym.make(
-            env_config["env_name"],
-            max_episode_steps=30,
-            mesh = mesh,
-            n_darts_selected=env_config["n_darts_selected"],
-            deep= env_config["deep"],
-            action_restriction=env_config["action_restriction"],
-            with_degree_obs=env_config["with_degree_observation"],
-            render_mode="human"
+            config["env"]["env_id"],
+            max_episode_steps=config["env"]["max_episode_steps"],
+            mesh=mesh,
+            #mesh_size = 30,
+            n_darts_selected=config["env"]["n_darts_selected"],
+            deep=config["env"]["deep"],
+            action_restriction=config["env"]["action_restriction"],
+            with_quality_obs=config["env"]["with_quality_observation"],
+            render_mode=config["env"]["render_mode"],
         )
         for _ in range(n_eval_episodes):
             terminated = False
@@ -62,8 +70,8 @@ def testPolicy(
                 dist = Categorical(pmf)
                 action = dist.sample()
                 action = action.tolist()
-                action_dart = int(action / 4)
-                action_type = action % 4
+                action_dart = int(action / config["ppo"]["n_actions"])
+                action_type = action % config["ppo"]["n_actions"]
                 gymnasium_action = [action_type, action_dart]
                 if action is None:
                     env.terminal = True
@@ -73,7 +81,7 @@ def testPolicy(
                 ep_length += 1
             if terminated:
                 nb_wins[i-1] += 1
-                if isBetterMesh(best_mesh, info['mesh']):
+                if isBetterMesh(best_mesh, info['mesh'], config["env"]["analysis_type"]):
                     best_mesh = copy.deepcopy(info['mesh'])
             avg_length[i-1] += ep_length
             avg_mesh_rewards[i-1] += ep_mesh_rewards
@@ -89,37 +97,57 @@ def isBetterPolicy(actual_best_policy, policy_to_test):
     if actual_best_policy is None:
         return True
 
-def isBetterMesh(best_mesh, actual_mesh):
-    if best_mesh is None or global_score(best_mesh)[1] > global_score(actual_mesh)[1]:
+def isBetterMesh(best_mesh, actual_mesh, analysis_type):
+    tri = False
+    for d_info in actual_mesh.dart_info:
+        if d_info[0]>=0:
+            d = Dart(actual_mesh, d_info[0])
+            if d == ((d.get_beta(1)).get_beta(1)).get_beta(1):
+                tri = True
+            else:
+                tri = False
+            break
+    if tri:
+        if analysis_type == "old":
+            ma_best_mesh = TriMeshOldAnalysis(best_mesh)
+            ma_actual_mesh = TriMeshOldAnalysis(actual_mesh)
+        else:
+            ma_best_mesh = TriMeshQualityAnalysis(best_mesh)
+            ma_actual_mesh = TriMeshQualityAnalysis(actual_mesh)
+    else:
+        ma_best_mesh = QuadMeshOldAnalysis(best_mesh)
+        ma_actual_mesh = QuadMeshOldAnalysis(actual_mesh)
+    if best_mesh is None or ma_best_mesh.global_score()[1] > ma_actual_mesh.global_score()[1]:
         return True
     else:
         return False
 
 
 if __name__ == '__main__':
 
-
     #Create a dataset of 9 meshes
-    mesh = read_gmsh("../mesh_files/medium_quad.msh")
+    mesh = read_gmsh("../mesh_files/t1_tri.msh")
     dataset = [mesh for _ in range(9)]
-    with open("../environment/old_files/environment_config.json", "r") as f:
-        env_config = json.load(f)
+    with open("../training/config/trimesh_config_PPO_perso.yaml", "r") as f:
+        config = yaml.safe_load(f)
     plot_dataset(dataset)
 
     env = gym.make(
-        env_config["env_name"],
+        config["env"]["env_id"],
+        max_episode_steps=config["env"]["max_episode_steps"],
         mesh=mesh,
-        max_episode_steps=env_config["max_episode_steps"],
-        n_darts_selected=env_config["n_darts_selected"],
-        deep=env_config["deep"],
-        action_restriction=env_config["action_restriction"],
-        with_degree_obs=env_config["with_degree_observation"]
+        # mesh_size = 30,
+        n_darts_selected=config["env"]["n_darts_selected"],
+        deep=config["env"]["deep"],
+        action_restriction=config["env"]["action_restriction"],
+        with_quality_obs=config["env"]["with_quality_observation"],
+        render_mode=config["env"]["render_mode"],
     )
 
     #Load the model
-    actor = Actor(env, 10*8, 4*10, lr=0.0001)
-    actor.load_state_dict(torch.load('policy_saved/quad-perso/medium_quad_perso-2.pth'))
-    avg_steps, avg_wins, avg_rewards, normalized_return, final_meshes = testPolicy(actor, 15, env_config, dataset)
+    actor = Actor(env, config["env"]["obs_size"], config["ppo"]["n_actions"], n_darts_observed=config["env"]["n_darts_selected"], lr=0.0001)
+    actor.load_state_dict(torch.load('policy_saved/tri-perso/TEST-Exploit.pth'))
+    avg_steps, avg_wins, avg_rewards, normalized_return, final_meshes = testPolicy(actor, 15, config, dataset)
 
     plot_test_results(avg_rewards, avg_wins, avg_steps, normalized_return)
     plot_dataset(final_meshes)
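
Note: the script now reads its settings from training/config/trimesh_config_PPO_perso.yaml (loaded with yaml.safe_load) instead of the old environment_config.json. Below is a minimal sketch of the dict shape the script expects; the key names are the ones referenced in this diff, but the values and the env id are illustrative assumptions, not the repository's actual configuration.

# Hypothetical shape of the dict returned by yaml.safe_load for
# trimesh_config_PPO_perso.yaml; only the key names come from the diff,
# the values are placeholders.
config = {
    "env": {
        "env_id": "TriMeshEnvFull-v0",      # assumed gym registration id
        "max_episode_steps": 30,
        "n_darts_selected": 10,
        "deep": 6,
        "action_restriction": False,
        "with_quality_observation": True,
        "render_mode": None,
        "analysis_type": "quality",         # "old" selects the *OldAnalysis classes
        "obs_size": 60,                     # actor input size
    },
    "ppo": {
        "n_actions": 4,                     # actions per dart; the flat action index
                                            # splits as dart = idx // n_actions,
                                            # type = idx % n_actions
    },
}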

training/exploit_SB3_policy.py

Lines changed: 0 additions & 1 deletion
@@ -118,7 +118,6 @@ def isBetterMesh(best_mesh, actual_mesh, analysis_type):
 
 if __name__ == '__main__':
 
-
     #Create a dataset of 9 meshes
     mesh = read_gmsh("../mesh_files/tri-star.msh")
     # ma = TriMeshQualityAnalysis(mesh)
