Skip to content

Commit 01042a0

Browse files
committed
Changes to quad training scripts.
Correction of an error in the PPO_perso model. Added two-way flip: clockwise and counterclockwise.
1 parent 45678f7 commit 01042a0

File tree

13 files changed

+374
-132
lines changed

13 files changed

+374
-132
lines changed

environment/gymnasium_envs/quadmesh_env/envs/quadmesh.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def __init__(self, mesh=None, max_episode_steps=30, n_darts_selected=20, deep=6,
4646
self.darts_selected = [] # darts id observed
4747
self.deep = deep*2 if self.degree_observation else deep
4848
self.observation_space = spaces.Box(
49-
low=-6, # nodes min degree : -15
50-
high=2, # nodes max degree : 15
49+
low=-6, # nodes min degree : -6
50+
high=2, # nodes max degree : 2
5151
shape=(self.n_darts_selected, deep),
5252
dtype=np.int64
5353
)
@@ -102,8 +102,7 @@ def _get_info(self, terminated, valid_act, action, mesh_reward):
102102
"split": 1.0 if action[0]==Actions.SPLIT.value else 0.0,
103103
"collapse": 1.0 if action[0]==Actions.COLLAPSE.value else 0.0,
104104
"cleanup": 1.0 if action[0]==Actions.CLEANUP.value else 0.0,
105-
"invalid_flip_cw": 1.0 if action[0]==Actions.FLIP_CW.value and not valid_action else 0.0,
106-
"invalid_flip_cntcw": 1.0 if action[0]==Actions.FLIP_CNTCW.value and not valid_action else 0.0,
105+
"invalid_flip": 1.0 if (action[0]==Actions.FLIP_CW.value or action[0]==Actions.FLIP_CNTCW.value) and not valid_action else 0.0,
107106
"invalid_split": 1.0 if action[0]==Actions.SPLIT.value and not valid_action else 0.0,
108107
"invalid_collapse": 1.0 if action[0]==Actions.COLLAPSE.value and not valid_action else 0.0,
109108
"invalid_cleanup": 1.0 if action[0]==Actions.CLEANUP.value and not valid_action else 0.0,

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22

33
from user_game import user_game
4-
from training.train_quadmesh import train
4+
from training.train import train
55
from training.exploit import exploit
66
#from mesh_model.reader import read_gmsh
77
#from mesh_display import MeshDisplay

mesh_model/mesh_analysis/global_mesh_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def test_degree(n: Node) -> bool:
233233
:param n: a Node
234234
:return: True if the degree is lower than 10, False otherwise
235235
"""
236-
if degree(n) > 10:
236+
if degree(n) >= 10:
237237
return False
238238
else:
239239
return True

mesh_model/mesh_analysis/quadmesh_analysis.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,43 @@
22

33
from mesh_model.mesh_struct.mesh_elements import Dart, Node, Face
44
from mesh_model.mesh_struct.mesh import Mesh
5-
from mesh_model.mesh_analysis.global_mesh_analysis import test_degree, on_boundary, adjacent_faces_id
5+
from mesh_model.mesh_analysis.global_mesh_analysis import test_degree, on_boundary, adjacent_faces_id, degree
6+
7+
FLIP_CW = 0 # flip clockwise
8+
FLIP_CCW = 1 # flip counterclockwise
9+
SPLIT = 2
10+
COLLAPSE = 3
11+
CLEANUP = 4
12+
TEST_ALL = 5 # test if all actions are valid
13+
ONE_VALID = 6 # test if at least one action is valid
614

715

816
def isValidAction(mesh: Mesh, dart_id: int, action: int) -> (bool, bool):
9-
flip_ccw = 0
10-
split = 1
11-
collapse = 2
12-
cleanup =3
13-
test_all = 4
14-
one_valid = 5
17+
"""
18+
Test if an action is valid. The type of action to test is one of {flip clockwise, flip counterclockwise, split, collapse, cleanup, all actions, at least one action (no matter which one)}.
19+
:param mesh: a mesh
20+
:param dart_id: a dart on which to test the action
21+
:param action: an action type
22+
:return:
23+
"""
1524
d = Dart(mesh, dart_id)
1625
if d.get_beta(2) is None:
1726
return False, True
18-
elif action == flip_ccw:
27+
elif action == FLIP_CW:
28+
return isFlipCWOk(d)
29+
elif action == FLIP_CCW:
1930
return isFlipCCWOk(d)
20-
elif action == split:
31+
elif action == SPLIT:
2132
return isSplitOk(d)
22-
elif action == collapse:
33+
elif action == COLLAPSE:
2334
return isCollapseOk(d)
24-
elif action == cleanup:
35+
elif action == CLEANUP:
2536
return isCleanupOk(d)
26-
elif action == test_all:
37+
elif action == TEST_ALL:
2738
topo, geo = isFlipCCWOk(d)
39+
if not (topo and geo):
40+
return False, False
41+
topo, geo = isFlipCWOk(d)
2842
if not (topo and geo):
2943
return False, False
3044
topo, geo = isSplitOk(d)
@@ -35,8 +49,11 @@ def isValidAction(mesh: Mesh, dart_id: int, action: int) -> (bool, bool):
3549
return False, False
3650
elif topo and geo:
3751
return True, True
38-
elif action == one_valid:
52+
elif action == ONE_VALID:
3953
topo_flip, geo_flip = isFlipCCWOk(d)
54+
if (topo_flip and geo_flip):
55+
return True, True
56+
topo_flip, geo_flip = isFlipCWOk(d)
4057
if (topo_flip and geo_flip):
4158
return True, True
4259
topo_split, geo_split = isSplitOk(d)
@@ -54,30 +71,26 @@ def isFlipCCWOk(d: Dart) -> (bool, bool):
5471
mesh = d.mesh
5572
topo = True
5673
geo = True
74+
5775
# if d is on boundary, flip is not possible
5876
if d.get_beta(2) is None:
5977
topo = False
6078
return topo, geo
6179
else:
6280
d2, d1, d11, d111, d21, d211, d2111, n1, n2, n3, n4, n5, n6 = mesh.active_quadrangles(d)
63-
# if degree are
81+
82+
# check that the degrees of the affected nodes will not become too high
6483
if not test_degree(n5) or not test_degree(n3):
6584
topo = False
6685
return topo, geo
6786

87+
# if two faces share two edges
6888
if d211.get_node() == d111.get_node() or d11.get_node() == d2111.get_node():
6989
topo = False
7090
return topo, geo
71-
topo = isValidQuad(n5, n6, n2, n3) and isValidQuad(n1, n5, n3, n4)
72-
73-
"""
74-
# Check angle at d limits to avoid edge reversal
75-
angle_A = get_angle_by_coord(n5.x(), n5.y(), n1.x(), n1.y(), n3.x(), n3.y())
7691

77-
if angle_A <= 90 or angle_A >= 180:
78-
topo = False
79-
return topo, geo
80-
"""
92+
# check validity of the two modified quads
93+
geo = isValidQuad(n5, n6, n2, n3) and isValidQuad(n1, n5, n3, n4)
8194

8295
return topo, geo
8396

@@ -99,7 +112,7 @@ def isFlipCWOk(d: Dart) -> (bool, bool):
99112
if d211.get_node() == d111.get_node() or d11.get_node() == d2111.get_node():
100113
topo = False
101114
return topo, geo
102-
topo = isValidQuad(n4, n6, n2, n3) and isValidQuad(n1, n5, n6, n4)
115+
geo = isValidQuad(n4, n6, n2, n3) and isValidQuad(n1, n5, n6, n4)
103116

104117
return topo, geo
105118

@@ -123,7 +136,7 @@ def isSplitOk(d: Dart) -> (bool, bool):
123136
return topo, geo
124137

125138
n10 = mesh.add_node((n1.x() + n2.x()) / 2, (n1.y() + n2.y()) / 2)
126-
topo = isValidQuad(n4, n1, n5, n10) and isValidQuad(n4, n10, n2, n3) and isValidQuad(n10, n5, n6, n2)
139+
geo = isValidQuad(n4, n1, n5, n10) and isValidQuad(n4, n10, n2, n3) and isValidQuad(n10, n5, n6, n2)
127140
mesh.del_node(n10)
128141
return topo, geo
129142

@@ -142,7 +155,7 @@ def isCollapseOk(d: Dart) -> (bool, bool):
142155
topo = False
143156
return topo, geo
144157

145-
if not test_degree(n3) and not test_degree(n1):
158+
if (degree(n3)+degree(n1)-2) > 10:
146159
topo = False
147160
return topo, geo
148161

@@ -198,7 +211,7 @@ def isCollapseOk(d: Dart) -> (bool, bool):
198211
D=n10
199212

200213
if not isValidQuad(A, B, C, D):
201-
topo = False
214+
geo = False
202215
mesh.del_node(n10)
203216
return topo, geo
204217

mesh_model/random_quadmesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def random_mesh() -> Mesh:
1515
:param num_nodes_max: number of nodes of the final mesh
1616
:return: a random mesh
1717
"""
18-
filename = os.path.join(os.path.dirname(__file__), '../mesh_files/t1_quad.msh')
18+
filename = os.path.join(os.path.dirname(__file__), '../mesh_files/simple_quad.msh')
1919
#filename = os.path.join('../mesh_files/', 't1_quad.msh')
2020
#mesh = read_gmsh("/home/ropercha/PycharmProjects/tune/mesh_files/t1_quad.msh")
2121
mesh = read_gmsh(filename)

model_RL/PPO_model_pers.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from mesh_model.mesh_analysis.global_mesh_analysis import global_score
22
import copy
33
import random
4+
import json
45
from tqdm import tqdm
56
import numpy as np
67
import torch
@@ -58,8 +59,8 @@ def select_action(self, observation, info):
5859
action = dist.sample()
5960
action = action.tolist()
6061
prob = pmf[action]
61-
action_dart = int(action/3)
62-
action_type = action % 3
62+
action_dart = int(action/4)
63+
action_type = action % 4
6364
dart_id = info["darts_selected"][action_dart]
6465
i = 0
6566
while not isValidAction(info["mesh"], dart_id, action_type):
@@ -70,8 +71,8 @@ def select_action(self, observation, info):
7071
action = dist.sample()
7172
action = action.tolist()
7273
prob = pmf[action]
73-
action_dart = int(action/3)
74-
action_type = action % 3
74+
action_dart = int(action/4)
75+
action_type = action % 4
7576
dart_id = info["darts_selected"][action_dart]
7677
i += 1
7778
action_list = [action, dart_id, action_type]
@@ -139,7 +140,7 @@ def learn(self, critic_loss):
139140
class PPO:
140141
def __init__(self, env, lr, gamma, nb_iterations, nb_episodes_per_iteration, nb_epochs, batch_size):
141142
self.env = env
142-
self.actor = Actor(env, 10*8, 3*10, lr=0.0001)
143+
self.actor = Actor(env, 10*8, 4*10, lr=0.0001)
143144
self.critic = Critic(8*10, lr=0.0001)
144145
self.lr = lr
145146
self.gamma = gamma
@@ -165,16 +166,14 @@ def train(self, dataset):
165166
critic_loss = []
166167
actor_loss = []
167168
self.critic.optimizer.zero_grad()
168-
G = 0
169-
for _, (s, o, a, r, old_prob, next_o, done) in enumerate(batch, 1):
169+
for _, (s, o, a, r, G, old_prob, next_o, done) in enumerate(batch, 1):
170170
o = torch.tensor(o.flatten(), dtype=torch.float32)
171171
next_o = torch.tensor(next_o.flatten(), dtype=torch.float32)
172172
value = self.critic(o)
173173
pmf = self.actor.forward(o)
174174
log_prob = torch.log(pmf[a[0]])
175175
next_value = torch.tensor(0.0, dtype=torch.float32) if done else self.critic(next_o)
176176
delta = r + 0.9 * next_value - value
177-
G = (r + 0.9 * G) / 10
178177
_, st, ideal_s, _ = global_score(s) # Comparaison à l'état s et pas s+1 ?
179178
if st == ideal_s:
180179
continue
@@ -221,6 +220,7 @@ def learn(self, writer):
221220
ep_reward = 0
222221
ep_mesh_reward = 0
223222
ideal_reward = info["mesh_ideal_rewards"]
223+
G = 0
224224
done = False
225225
step = 0
226226
while step < 40:
@@ -230,20 +230,21 @@ def learn(self, writer):
230230
if action is None:
231231
wins.append(0)
232232
break
233-
gym_action = [action[2],int(action[0]/3)]
233+
gym_action = [action[2],int(action[0]/4)]
234234
next_obs, reward, terminated, truncated, info = self.env.step(gym_action)
235235
ep_reward += reward
236236
ep_mesh_reward += info["mesh_reward"]
237+
G = info["mesh_reward"] + 0.9 * G
237238
if terminated:
238239
if truncated:
239240
wins.append(0)
240-
trajectory.append((state, obs, action, reward, prob, next_obs, done))
241+
trajectory.append((state, obs, action, reward, G, prob, next_obs, done))
241242
else:
242243
wins.append(1)
243244
done = True
244-
trajectory.append((state, obs, action, reward, prob, next_obs, done))
245+
trajectory.append((state, obs, action, reward, G, prob, next_obs, done))
245246
break
246-
trajectory.append((state, obs, action, reward, prob, next_obs, done))
247+
trajectory.append((state, obs, action, reward, G, prob, next_obs, done))
247248
step += 1
248249
if len(trajectory) != 0:
249250
rewards.append(ep_reward)
@@ -252,7 +253,8 @@ def learn(self, writer):
252253
len_ep.append(len(trajectory))
253254
nb_episodes += 1
254255
writer.add_scalar("episode_reward", ep_reward, nb_episodes)
255-
writer.add_scalar("normalized return", (ep_reward/ideal_reward), nb_episodes)
256+
writer.add_scalar("episode_mesh_reward", ep_mesh_reward, nb_episodes)
257+
writer.add_scalar("normalized return", (ep_mesh_reward/ideal_reward), nb_episodes)
256258
writer.add_scalar("len_episodes", len(trajectory), nb_episodes)
257259

258260
self.train(dataset)
@@ -263,4 +265,5 @@ def learn(self, writer):
263265
except NaNExceptionCritic:
264266
print("NaN Exception on Critic Network")
265267
return None, None, None, None
266-
return self.actor, rewards, wins, len_ep
268+
269+
return self.actor, rewards, wins, len_ep, info["observation_count"]

model_RL/parameters/ppo_config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
"learning_rate": 0.0001,
77
"gamma": 0.9,
88
"verbose": 1,
9-
"tensorboard_log": "./results/quad/",
9+
"tensorboard_log": "training/results/quad/",
1010
"total_timesteps": 80000
1111
}

0 commit comments

Comments
 (0)