Skip to content

Commit f904e1f

Browse files
authored
Merge pull request #52 from LIHPC-Computational-Geometry/45-RL_Actor_Critic
45 RL environment and model
2 parents f66cba8 + d03f6d8 commit f904e1f

21 files changed

+1047
-77
lines changed

actions/triangular_actions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from model.mesh_struct.mesh import Mesh
44
from model.mesh_struct.mesh_elements import Dart, Node
5-
from model.mesh_analysis import degree
5+
from model.mesh_analysis import degree, isFlipOk
66

77

88
def flip_edge_ids(mesh: Mesh, id1: int, id2: int) -> True:
@@ -11,7 +11,8 @@ def flip_edge_ids(mesh: Mesh, id1: int, id2: int) -> True:
1111

1212
def flip_edge(mesh: Mesh, n1: Node, n2: Node) -> True:
1313
found, d = mesh.find_inner_edge(n1, n2)
14-
if not found:
14+
15+
if not found or not isFlipOk(d):
1516
return False
1617

1718
d2, d1, d11, d21, d211, n1, n2, n3, n4 = active_triangles(mesh, d)
@@ -113,5 +114,5 @@ def test_degree(n: Node) -> bool:
113114
"""
114115
if degree(n) > 10:
115116
return False
116-
117-
117+
else:
118+
return True

environment/__init__.py

Whitespace-only changes.

environment/trimesh_env.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from typing import Any
2+
import numpy as np
3+
from model.mesh_analysis import global_score, isValidAction, find_template_opposite_node
4+
from model.mesh_struct.mesh_elements import Dart
5+
from model.mesh_struct.mesh import Mesh
6+
from actions.triangular_actions import flip_edge
7+
from model.random_trimesh import random_flip_mesh
8+
9+
# possible actions
10+
FLIP = 0
11+
GLOBAL = 0
12+
13+
14+
class TriMesh:
15+
def __init__(self, mesh=None, mesh_size: int = None, max_steps: int = 50, feat: int = 0):
16+
self.mesh = mesh if mesh is not None else random_flip_mesh(mesh_size)
17+
self.mesh_size = len(self.mesh.nodes)
18+
self.size = len(self.mesh.dart_info)
19+
self.actions = np.array([FLIP])
20+
self.reward = 0
21+
self.steps = 0
22+
self.max_steps = max_steps
23+
self.nodes_scores = global_score(self.mesh)[0]
24+
self.ideal_score = global_score(self.mesh)[2]
25+
self.terminal = False
26+
self.feat = feat
27+
self.won = 0
28+
29+
def reset(self, mesh=None):
30+
self.reward = 0
31+
self.steps = 0
32+
self.terminal = False
33+
self.mesh = mesh if mesh is not None else random_flip_mesh(self.mesh_size)
34+
self.size = len(self.mesh.dart_info)
35+
self.nodes_scores = global_score(self.mesh)[0]
36+
self.ideal_score = global_score(self.mesh)[2]
37+
self.won = 0
38+
39+
def step(self, action):
40+
dart_id = action[1]
41+
_, mesh_score, mesh_ideal_score = global_score(self.mesh)
42+
d = Dart(self.mesh, dart_id)
43+
d1 = d.get_beta(1)
44+
n1 = d.get_node()
45+
n2 = d1.get_node()
46+
flip_edge(self.mesh, n1, n2)
47+
self.steps += 1
48+
next_nodes_score, next_mesh_score, _ = global_score(self.mesh)
49+
self.nodes_scores = next_nodes_score
50+
self.reward = (mesh_score - next_mesh_score)*10
51+
if self.steps >= self.max_steps or next_mesh_score == mesh_ideal_score:
52+
if next_mesh_score == mesh_ideal_score:
53+
self.won = True
54+
self.terminal = True
55+
56+
def get_x(self, s: Mesh, a: int) -> tuple[Any, list[int | list[int]]]:
57+
"""
58+
Get the feature vector of the state-action pair
59+
:param s: the state
60+
:param a: the action
61+
:return: the feature vector and valid darts id
62+
"""
63+
if s is None:
64+
s = self.mesh
65+
if self.feat == GLOBAL:
66+
return get_x_global_4(self, s)
67+
68+
69+
def get_x_global_4(env, state: Mesh) -> tuple[Any, list[int | list[int]]]:
70+
"""
71+
Get the feature vector of the state.
72+
:param state: the state
73+
:param env: The environment
74+
:return: the feature vector
75+
"""
76+
mesh = state
77+
nodes_scores = global_score(mesh)[0]
78+
size = len(mesh.dart_info)
79+
template = np.zeros((size, 6))
80+
81+
for d_info in mesh.dart_info:
82+
83+
d = Dart(mesh, d_info[0])
84+
A = d.get_node()
85+
d1 = d.get_beta(1)
86+
B = d1.get_node()
87+
d11 = d1.get_beta(1)
88+
C = d11.get_node()
89+
90+
#Template niveau 1
91+
template[d_info[0], 0] = nodes_scores[C.id]
92+
template[d_info[0], 1] = nodes_scores[A.id]
93+
template[d_info[0], 2] = nodes_scores[B.id]
94+
95+
#template niveau 2
96+
97+
n_id = find_template_opposite_node(d)
98+
if n_id is not None:
99+
template[d_info[0], 3] = nodes_scores[n_id]
100+
n_id = find_template_opposite_node(d1)
101+
if n_id is not None:
102+
template[d_info[0], 4] = nodes_scores[n_id]
103+
n_id = find_template_opposite_node(d11)
104+
if n_id is not None:
105+
template[d_info[0], 5] = nodes_scores[n_id]
106+
107+
dart_to_delete = []
108+
dart_ids = []
109+
for i in range(size):
110+
d = Dart(mesh, i)
111+
if not isValidAction(mesh, d.id):
112+
dart_to_delete.append(i)
113+
else :
114+
dart_ids.append(i)
115+
valid_template = np.delete(template, dart_to_delete, axis=0)
116+
score_sum = np.sum(np.abs(valid_template), axis=1)
117+
indices_top_10 = np.argsort(score_sum)[-5:][::-1]
118+
valid_dart_ids = [dart_ids[i] for i in indices_top_10]
119+
X = valid_template[indices_top_10, :]
120+
X = X.flatten()
121+
return X, valid_dart_ids

main.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,13 @@
1-
from view.window import Game
2-
from model.mesh_struct.mesh import Mesh
3-
from mesh_display import MeshDisplay
4-
import model.random_trimesh as TM
5-
#import model.reader as Reader
6-
71
import sys
8-
import json
2+
3+
from user_game import user_game
4+
from train import train
95

106

117
# Press the green button in the gutter to run the script.
128
if __name__ == '__main__':
139

14-
if len(sys.argv) != 2:
15-
print("Usage: main.py <nb_nodes_of_the_mesh>")
16-
else:
17-
cmap = TM.random_mesh(int(sys.argv[1]))
18-
mesh_disp = MeshDisplay(cmap)
19-
g = Game(cmap, mesh_disp)
20-
g.run()
21-
22-
"""
23-
#Code to load a json file and create a mesh
24-
25-
if len(sys.argv) != 2:
26-
print("Usage: main.py <mesh_file.json>")
27-
else:
28-
f = open(sys.argv[1])
29-
json_mesh = json.load(f)
30-
cmap = Mesh(json_mesh['nodes'], json_mesh['faces'])
31-
mesh_disp = MeshDisplay(cmap)
32-
g = Game(cmap, mesh_disp)
33-
g.run()
34-
"""
35-
"""
36-
# Code to load a Medit .mesh file and create a mesh
37-
if len(sys.argv) != 2:
38-
print("Usage: main.py <mesh_file.mesh>")
39-
else:
40-
cmap = Reader.read_medit(sys.argv[1])
41-
mesh_disp = MeshDisplay(cmap)
42-
g = Game(cmap, mesh_disp)
43-
g.run()
44-
"""
45-
"""
46-
# Code to load a gmsh .msh file and create a mesh
47-
if len(sys.argv) != 2:
48-
print("Usage: main.py <mesh_file.msh>")
10+
if len(sys.argv) == 2:
11+
user_game(int(sys.argv[1]))
4912
else:
50-
cmap = Reader.read_gmsh(sys.argv[1])
51-
mesh_disp = MeshDisplay(cmap)
52-
g = Game(cmap, mesh_disp)
53-
g.run()
54-
"""
13+
train()

model/mesh_analysis.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from math import sqrt, degrees, radians, cos, sin
1+
from math import sqrt, degrees, radians, cos, sin, acos
22
import numpy as np
33

44
from model.mesh_struct.mesh_elements import Dart, Node
@@ -113,9 +113,9 @@ def adjacent_darts(n: Node) -> list[Dart]:
113113
d = Dart(n.mesh, d_info[0])
114114
d_nfrom = d.get_node()
115115
d_nto = d.get_beta(1)
116-
if d_nfrom == n:
116+
if d_nfrom == n and d not in adj_darts:
117117
adj_darts.append(d)
118-
if d_nto.get_node() == n:
118+
if d_nto.get_node() == n and d not in adj_darts:
119119
adj_darts.append(d)
120120
return adj_darts
121121

@@ -137,7 +137,8 @@ def degree(n: Node) -> int:
137137
boundary_darts.append(d)
138138
else:
139139
adjacency += 0.5
140-
140+
if adjacency != int(adjacency):
141+
raise ValueError("Adjacency error")
141142
return adjacency
142143

143144

@@ -191,3 +192,78 @@ def find_opposite_node(d: Dart) -> (int, int):
191192
y_C = A.y() + y_AC
192193

193194
return x_C, y_C
195+
196+
def find_template_opposite_node(d: Dart) -> (int):
197+
"""
198+
Find the the vertex opposite in the adjacent triangle
199+
:param d: a dart
200+
:return: the node found
201+
"""
202+
203+
d2 = d.get_beta(2)
204+
if d2 is not None:
205+
d21 = d2.get_beta(1)
206+
d211 = d21.get_beta(1)
207+
node_opposite = d211.get_node()
208+
return node_opposite.id
209+
else:
210+
return None
211+
212+
213+
def node_in_mesh(mesh: Mesh, x: float, y: float) -> (bool, int):
214+
"""
215+
Search if the node of coordinate (x, y) is inside the mesh.
216+
:param mesh: the mesh to work with
217+
:param x: X coordinate
218+
:param y: Y coordinate
219+
:return: a boolean indicating if the node is inside the mesh and the id of the node if it is.
220+
"""
221+
n_id = 0
222+
for n in mesh.nodes:
223+
if abs(x - n[0]) <= 0.1 and abs(y - n[1]) <= 0.1:
224+
return True, n_id
225+
n_id = n_id + 1
226+
return False, None
227+
228+
229+
def isValidAction(mesh, dart_id: int) -> bool:
230+
d = Dart(mesh, dart_id)
231+
boundary_darts = get_boundary_darts(mesh)
232+
if d in boundary_darts or not isFlipOk(d):
233+
return False
234+
else:
235+
return True
236+
237+
def get_angle_by_coord(x1: float, y1: float, x2: float, y2: float, x3:float, y3:float) -> float:
238+
BAx, BAy = x1 - x2, y1 - y2
239+
BCx, BCy = x3 - x2, y3 - y2
240+
241+
cos_ABC = (BAx * BCx + BAy * BCy) / (sqrt(BAx ** 2 + BAy ** 2) * sqrt(BCx ** 2 + BCy ** 2))
242+
243+
rad = acos(cos_ABC)
244+
deg = degrees(rad)
245+
return deg
246+
247+
248+
def isFlipOk(d:Dart) -> bool:
249+
d1 = d.get_beta(1)
250+
d11 = d1.get_beta(1)
251+
A = d.get_node()
252+
B = d1.get_node()
253+
C = d11.get_node()
254+
d2 = d.get_beta(2)
255+
if d2 is None:
256+
return False
257+
else:
258+
d21 = d2.get_beta(1)
259+
d211 = d21.get_beta(1)
260+
D = d211.get_node()
261+
262+
# Calcul angle at d limits
263+
angle_B = get_angle_by_coord(A.x(), A.y(), B.x(), B.y(), C.x(), C.y()) + get_angle_by_coord(A.x(), A.y(), B.x(), B.y(), D.x(), D.y())
264+
angle_A = get_angle_by_coord(B.x(), B.y(), A.x(), A.y(), C.x(), C.y()) + get_angle_by_coord(B.x(), B.y(), A.x(), A.y(), D.x(), D.y())
265+
266+
if angle_B >= 180 or angle_A >= 180:
267+
return False
268+
else:
269+
return True

0 commit comments

Comments
 (0)