Skip to content

Commit 018b517

Browse files
committed
Fix policy evaluation problem
1 parent d700a80 commit 018b517

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

environment/gymnasium_envs/trimesh_full_env/envs/trimesh.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ def reset(self, seed=None, options=None):
6969
# We need the following line to seed self.np_random
7070
super().reset(seed=seed)
7171
if options is not None:
72-
self.mesh = options.get("mesh", self.mesh)
72+
self.mesh = options['mesh']
7373
else:
7474
self.mesh = random_mesh(self.mesh_size)
7575
self.nb_darts = len(self.mesh.dart_info)

exploit_SB3_policy.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from numpy import ndarray
22

33
import gymnasium as gym
4+
import json
45

56
from environment.gymnasium_envs.trimesh_flip_env import TriMeshEnvFlip
67
from environment.gymnasium_envs.trimesh_full_env import TriMeshEnvFull
@@ -50,9 +51,9 @@ def testPolicy(
5051
truncated = False
5152
ep_mesh_rewards: int = 0
5253
ep_length: int = 0
53-
obs, info = env.reset(options={"mesh": mesh})
54+
obs, info = env.reset(options={"mesh": copy.deepcopy(mesh)})
5455
while terminated == False and truncated == False:
55-
action, _states = model.predict(obs, deterministic=True)
56+
action, _states = model.predict(obs, deterministic=False)
5657
if action is None:
5758
env.terminal = True
5859
break
@@ -65,7 +66,7 @@ def testPolicy(
6566
best_mesh = copy.deepcopy(info['mesh'])
6667
avg_length[i-1] += ep_length
6768
avg_mesh_rewards[i-1] += ep_mesh_rewards
68-
avg_normalized_return[i-1] += ep_mesh_rewards/info['mesh_ideal_rewards']
69+
avg_normalized_return[i-1] += 0 if info['mesh_ideal_rewards'] == 0 else ep_mesh_rewards/info['mesh_ideal_rewards']
6970
final_meshes.append(best_mesh)
7071
avg_length[i-1] = avg_length[i-1]/n_eval_episodes
7172
avg_mesh_rewards[i-1] = avg_mesh_rewards[i-1]/n_eval_episodes
@@ -83,13 +84,14 @@ def isBetterMesh(best_mesh, actual_mesh):
8384
else:
8485
return False
8586

87+
8688
"""
8789
dataset = [TM.random_mesh(30) for _ in range(9)]
8890
with open("environment/parameters/environment_config.json", "r") as f:
8991
env_config = json.load(f)
9092
plot_dataset(dataset)
91-
model = PPO.load("policy_saved/final-2.zip")
92-
avg_steps, avg_wins, avg_rewards, avg_normalized_return, final_meshes = testPolicy(model, 5, env_config, dataset)
93+
model = PPO.load("policy_saved/final/final-PPO-3.zip")
94+
avg_steps, avg_wins, avg_rewards, avg_normalized_return, final_meshes = testPolicy(model, 10, env_config, dataset)
9395
9496
plot_test_results(avg_rewards, avg_wins, avg_steps, avg_normalized_return)
9597
plot_dataset(final_meshes)

mesh_display.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,5 +35,5 @@ def get_scores(self):
3535
Calculates the irregularities of each node and the real and ideal score of the mesh
3636
:return: a list of three elements (nodes_score, mesh_score, ideal_mesh_score)
3737
"""
38-
nodes_score, mesh_score, ideal_mesh_score = global_score(self.mesh)
38+
nodes_score, mesh_score, ideal_mesh_score, adjacency = global_score(self.mesh)
3939
return [nodes_score, mesh_score, ideal_mesh_score]

policy_saved/final/final-PPO-3.zip

336 KB
Binary file not shown.

0 commit comments

Comments
 (0)