
Commit 6c69983: Fix issues
Parent: ed50fd8

3 files changed: 12 additions, 9 deletions

environment/gymnasium_envs/quadmesh_env/envs/quadmesh.py

Lines changed: 6 additions & 6 deletions
@@ -51,13 +51,13 @@ class QuadMeshEnv(gym.Env):
     def __init__(
         self,
         mesh=None,
-        max_episode_steps: int =50,
-        n_darts_selected: int =20,
-        deep: int =6,
+        max_episode_steps: int = 50,
+        n_darts_selected: int = 20,
+        deep: int = 6,
         render_mode: Optional[str] = None,
-        with_degree_obs: bool =True,
-        action_restriction: bool =False,
-        obs_count: bool=False,
+        with_degree_obs: bool = True,
+        action_restriction: bool = False,
+        obs_count: bool = False,
     ) -> None:
 
         assert render_mode is None or render_mode in self.metadata["render_modes"]
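For context, a minimal usage sketch of the constructor as it reads after this commit. The import path is inferred from the file tree, and whether mesh=None is workable without a real mesh depends on code outside this hunk:

# Sketch only: instantiating QuadMeshEnv with the defaults shown above.
from environment.gymnasium_envs.quadmesh_env.envs.quadmesh import QuadMeshEnv

env = QuadMeshEnv(
    mesh=None,
    max_episode_steps=50,
    n_darts_selected=20,
    deep=6,
    render_mode=None,
    with_degree_obs=True,
    action_restriction=False,
    obs_count=False,
)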

training/exploit_PPO_perso.py

Lines changed: 3 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch
 from torch.distributions import Categorical
 from model_RL.PPO_model_pers import Actor
-from stable_baselines3 import PPO
+
 from mesh_model.mesh_analysis.global_mesh_analysis import global_score
 from mesh_model.mesh_struct.mesh import Mesh
 from mesh_model.reader import read_gmsh
@@ -57,7 +57,7 @@ def testPolicy(
     ep_mesh_rewards: int = 0
     ep_length: int = 0
     observation, info = env.reset(options={"mesh": copy.deepcopy(mesh)})
-    while terminated == False and truncated == False:
+    while terminated is False and truncated is False:
         obs = torch.tensor(observation.flatten(), dtype=torch.float32)
         pmf = actor.forward(obs)
         dist = Categorical(pmf)
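The is False form avoids the == False comparison that linters flag (flake8 E712). Since Gymnasium's terminated and truncated flags are plain booleans, the more common loop shape is while not (terminated or truncated). A minimal sketch of that idiom, with policy standing in as a placeholder for the actor sampling above:

# Generic Gymnasium rollout loop (sketch; `policy` is a placeholder).
observation, info = env.reset()
terminated, truncated = False, False
while not (terminated or truncated):
    action = policy(observation)
    observation, reward, terminated, truncated, info = env.step(action)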
@@ -96,6 +96,7 @@ def isBetterMesh(best_mesh, actual_mesh):
     else:
         return False
 
+
 if __name__ == '__main__':
 
 

training/exploit_trimesh.py

Lines changed: 3 additions & 1 deletion
@@ -32,7 +32,9 @@ def exploit():
     plot_test_results(avg_rewards, avg_wins, avg_steps)
     plot_dataset(final_meshes)
 
+
 if __name__ == '__main__':
+
     mesh = read_gmsh("../mesh_files/t1_quad.msh")
 
     #Create a dataset of 9 meshes
@@ -56,7 +58,7 @@ def exploit():
     actor.load_state_dict(torch.load('policy_saved/actor_network.pth'))
     avg_steps, avg_wins, avg_rewards, final_meshes = testPolicy(actor, 15, dataset, 20)
 
-    plot_test_results(avg_rewards, avg_wins, avg_steps)
+    plot_test_results(avg_rewards, avg_wins, avg_steps, avg_wins)
     plot_dataset(final_meshes)
     for m in final_meshes:
         smoothing_mean(m)
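An aside on the unchanged torch.load call above: when the file holds a bare state dict, recent PyTorch versions accept weights_only=True, which restricts unpickling to tensors and plain containers. A hedged sketch, assuming the same checkpoint path:

# Sketch: state-dict loading with restricted unpickling (recent PyTorch).
state = torch.load('policy_saved/actor_network.pth', weights_only=True)
actor.load_state_dict(state)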
