Skip to content

Commit 3698b57

Browse files
committed
New geometrical analysis for triangular meshes.
A geometrical quality criterion is now included in the observation. This is so far only supported in a trimesh_full environment with action flip.
1 parent b79642e commit 3698b57

File tree

12 files changed

+471
-157
lines changed

12 files changed

+471
-157
lines changed

environment/actions/triangular_actions.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ def flip_edge(mesh_analysis, n1: Node, n2: Node) -> True:
3030
found, d = mesh_analysis.mesh.find_inner_edge(n1, n2)
3131
mesh_before = deepcopy(mesh_analysis.mesh)
3232
if found:
33-
topo, geo = mesh_analysis.isSplitOk(d)
33+
topo, geo = mesh_analysis.isFlipOk(d)
3434
if not geo or not topo:
3535
return False, topo, geo
3636
else:
37-
return False, False, True # the geometrical criteria is True because if the dart is not found, it means it's a boundary dart
37+
return False, False, True # the geometrical criteria is True because if the dart is not found, it means it's a boundary dart and a topological criteria
3838

3939

4040
d2, d1, d11, d21, d211, n1, n2, n3, n4 = mesh_analysis.mesh.active_triangles(d)
@@ -65,10 +65,24 @@ def flip_edge(mesh_analysis, n1: Node, n2: Node) -> True:
6565
d211.set_face(f1)
6666
d11.set_face(f2)
6767

68-
topo = check_mesh(mesh_analysis, mesh_before)
69-
if not topo:
70-
mesh_analysis.mesh = deepcopy(mesh_before)
71-
valid_action = False
68+
#Update dart quality and nodes scores
69+
d.set_quality(mesh_analysis.get_dart_geometric_quality(d))
70+
d1.set_quality(mesh_analysis.get_dart_geometric_quality(d1))
71+
d11.set_quality(mesh_analysis.get_dart_geometric_quality(d11))
72+
d21.set_quality(mesh_analysis.get_dart_geometric_quality(d21))
73+
d211.set_quality(mesh_analysis.get_dart_geometric_quality(d211))
74+
75+
n1.set_score(n1.get_score() + 1)
76+
n2.set_score(n2.get_score() + 1)
77+
n3.set_score(n3.get_score() - 1)
78+
n4.set_score(n4.get_score() - 1)
79+
80+
after_check = check_mesh(mesh_analysis, mesh_before)
81+
if not after_check:
82+
raise ValueError("Some checks are missing")
83+
# if not topo:
84+
# mesh_analysis.mesh = deepcopy(mesh_before)
85+
# valid_action = False
7286
return valid_action, topo, geo
7387

7488

@@ -222,6 +236,7 @@ def check_mesh(mesh_analysis, mesh_before=None) -> bool:
222236
# if beta2 relation is not symetrical
223237
elif d2 >= 0 and mesh_analysis.mesh.dart_info[d2, 2] != d:
224238
return False
239+
225240
# null dart
226241
elif d2>=0 and mesh_analysis.mesh.dart_info[d2, 3] == mesh_analysis.mesh.dart_info[d, 3]:
227242
return False
@@ -236,13 +251,18 @@ def check_mesh(mesh_analysis, mesh_before=None) -> bool:
236251
#Check beta1
237252
if mesh_analysis.mesh.dart_info[d11,1]!=d :
238253
return False
254+
#check if the quality is the same for twin darts
255+
if d2>=0 and mesh_analysis.mesh.dart_info[d2,5] != mesh_analysis.mesh.dart_info[d, 5]:
256+
plot_mesh(mesh_analysis.mesh)
257+
return False
239258

240259
if d2 >= 0 :
241260
d = Dart(mesh_analysis.mesh, d)
242261
d2, d1, d11, d21, d211, n1, n2, n3, n4 = mesh_analysis.mesh.active_triangles(d)
243-
if len(set([n1.id, n2.id, n3.id, n4.id])) < 4:
262+
if len(set([n1.id, n2.id, n3.id, n4.id])) < 4 and d.get_quality() != 3: # not flat faces
263+
plot_mesh(mesh_analysis.mesh)
244264
return False
245-
return True
265+
return True
246266

247267

248268
def check_mesh_debug(mesh_analysis, mesh_before=None)->True:

environment/gymnasium_envs/quadmesh_env/wrappers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@
22
from environment.gymnasium_envs.quadmesh_env.wrappers.discrete_actions import DiscreteActions
33
from environment.gymnasium_envs.quadmesh_env.wrappers.reacher_weighted_reward import ReacherRewardWrapper
44
from environment.gymnasium_envs.quadmesh_env.wrappers.relative_position import RelativePosition
5-
from environment.gymnasium_envs.quadmesh_env.wrappers.dummy_vertices_observation import

environment/gymnasium_envs/trimesh_full_env/envs/mesh_conv.py

Lines changed: 61 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import numpy as np
2-
from mesh_model.mesh_analysis.trimesh_analysis import TriMeshGeoAnalysis, TriMeshTopoAnalysis
3-
from mesh_model.mesh_struct.mesh_elements import Dart
2+
from mesh_model.mesh_struct.mesh_elements import Dart, Node
43
from mesh_model.mesh_struct.mesh import Mesh
54

65

76
def get_x(m_analysis, n_darts_selected: int, deep :int, degree: bool, restricted:bool, nodes_scores: list[int], nodes_adjacency: list[int]):
87
mesh = m_analysis.mesh
98
if degree:
10-
template, darts_id = get_template_deg(m_analysis, deep, nodes_scores, nodes_adjacency)
9+
template, darts_id = get_template_with_quality(m_analysis, deep)
1110
else:
1211
template, darts_id = get_template(m_analysis, deep, nodes_scores)
1312

@@ -48,58 +47,71 @@ def get_template(m_analysis, deep: int, nodes_scores):
4847
C = d11.get_node()
4948

5049
# Template niveau 1
51-
template[n_darts - 1, 0] = nodes_scores[C.id]
52-
template[n_darts - 1, 1] = nodes_scores[A.id]
53-
template[n_darts - 1, 2] = nodes_scores[B.id]
50+
template[n_darts - 1, 0] = C.get_score()
51+
template[n_darts - 1, 1] = A.get_score()
52+
template[n_darts - 1, 2] = B.get_score()
5453

5554
if deep>3:
5655
# template niveau 2 deep = 6
5756
n_id = m_analysis.find_template_opposite_node(d)
5857
if n_id is not None:
59-
template[n_darts - 1, 3] = nodes_scores[n_id]
58+
template[n_darts - 1, 3] = n_id.get_score()
6059
n_id = m_analysis.find_template_opposite_node(d1)
6160
if n_id is not None:
62-
template[n_darts - 1, 4] = nodes_scores[n_id]
61+
template[n_darts - 1, 4] = n_id.get_score()
6362
n_id = m_analysis.find_template_opposite_node(d11)
6463
if n_id is not None:
65-
template[n_darts - 1, 5] = nodes_scores[n_id]
64+
template[n_darts - 1, 5] = n_id.get_score()
6665

6766
if deep>6:
6867
# template niveau 3 - deep = 12
6968
d2, d1, d11, d21, d211, n1, n2, n3, n4 = m_analysis.mesh.active_triangles(d)
7069
#Triangle F2
7170
n_id = m_analysis.find_template_opposite_node(d21)
7271
if n_id is not None:
73-
template[n_darts - 1, 6] = nodes_scores[n_id]
72+
template[n_darts - 1, 6] = n_id.get_score()
7473
n_id = m_analysis.find_template_opposite_node(d211)
7574
if n_id is not None:
76-
template[n_darts - 1, 7] = nodes_scores[n_id]
75+
template[n_darts - 1, 7] = n_id.get_score()
7776
# Triangle T3
7877
d12 = d1.get_beta(2)
7978
d121 = d12.get_beta(1)
8079
d1211 = d121.get_beta(1)
8180
n_id = m_analysis.find_template_opposite_node(d121)
8281
if n_id is not None:
83-
template[n_darts - 1, 8] = nodes_scores[n_id]
82+
template[n_darts - 1, 8] = n_id.get_score()
8483
n_id = m_analysis.find_template_opposite_node(d1211)
8584
if n_id is not None:
86-
template[n_darts - 1, 9] = nodes_scores[n_id]
85+
template[n_darts - 1, 9] = n_id.get_score()
8786
# Triangle T4
8887
d112 = d11.get_beta(2)
8988
d1121 = d112.get_beta(1)
9089
d11211 = d1121.get_beta(1)
9190
n_id = m_analysis.find_template_opposite_node(d1121)
9291
if n_id is not None:
93-
template[n_darts - 1, 10] = nodes_scores[n_id]
92+
template[n_darts - 1, 10] = n_id.get_score()
9493
n_id = m_analysis.find_template_opposite_node(d11211)
9594
if n_id is not None:
96-
template[n_darts - 1, 11] = nodes_scores[n_id]
95+
template[n_darts - 1, 11] = n_id.get_score()
9796

9897
template = template[:n_darts, :]
9998

10099
return template, dart_ids
101100

102-
def get_template_deg(m_analysis, deep: int, nodes_scores, nodes_adjacency):
101+
def get_template_with_quality(m_analysis, deep: int):
102+
"""
103+
Create a template matrix representing the entire mesh, composed of:
104+
105+
* Node scores: i.e. the difference between ideal adjacency and actual adjacency.
106+
* Dart surrounding quality: a measure of the geometric quality around each dart.
107+
108+
Each column in the matrix corresponds to the local surrounding of a dart,
109+
including the scores of its surrounding nodes and its associated quality.
110+
111+
:param m_analysis: mesh to analyze
112+
:param deep: observation deep (how many nodes observed on each dart surrounding)
113+
:return: template matrix
114+
"""
103115
size = len(m_analysis.mesh.dart_info)
104116
template = np.zeros((size, deep*2), dtype=np.int64)
105117
dart_ids = []
@@ -117,27 +129,30 @@ def get_template_deg(m_analysis, deep: int, nodes_scores, nodes_adjacency):
117129
C = d11.get_node()
118130

119131
# Template niveau 1
120-
template[n_darts - 1, 0] = nodes_scores[C.id]
121-
template[n_darts - 1, deep] = nodes_adjacency[C.id]
122-
template[n_darts - 1, 1] = nodes_scores[A.id]
123-
template[n_darts - 1, deep+1] = nodes_adjacency[A.id]
124-
template[n_darts - 1, 2] = nodes_scores[B.id]
125-
template[n_darts - 1, deep+2] = nodes_adjacency[B.id]
132+
template[n_darts - 1, 0] = C.get_score()
133+
template[n_darts - 1, deep] = d.get_quality()
134+
template[n_darts - 1, 1] = A.get_score()
135+
template[n_darts - 1, deep+1] = d1.get_quality()
136+
template[n_darts - 1, 2] = B.get_score()
137+
template[n_darts - 1, deep+2] = d11.get_quality()
126138

127139
if deep>3:
128140
# template niveau 2
129141
n_id = m_analysis.find_template_opposite_node(d)
130142
if n_id is not None:
131-
template[n_darts - 1, 3] = nodes_scores[n_id]
132-
template[n_darts - 1, deep+3] = nodes_adjacency[n_id]
143+
n = Node(m_analysis.mesh, n_id)
144+
template[n_darts - 1, 3] = n.get_score()
145+
template[n_darts - 1, deep+3] = d.get_quality() #quality around dart d is equivalent to quality around dart d2
133146
n_id = m_analysis.find_template_opposite_node(d1)
134147
if n_id is not None:
135-
template[n_darts - 1, 4] = nodes_scores[n_id]
136-
template[n_darts - 1, deep+4] = nodes_adjacency[n_id]
148+
n = Node(m_analysis.mesh, n_id)
149+
template[n_darts - 1, 3] = n.get_score()
150+
template[n_darts - 1, deep+4] = d1.get_quality()
137151
n_id = m_analysis.find_template_opposite_node(d11)
138152
if n_id is not None:
139-
template[n_darts - 1, 5] = nodes_scores[n_id]
140-
template[n_darts - 1, deep+5] = nodes_adjacency[n_id]
153+
n = Node(m_analysis.mesh, n_id)
154+
template[n_darts - 1, 3] = n.get_score()
155+
template[n_darts - 1, deep+5] = d11.get_quality()
141156

142157
if deep>6:
143158
# template niveau 3 - deep = 12
@@ -146,38 +161,44 @@ def get_template_deg(m_analysis, deep: int, nodes_scores, nodes_adjacency):
146161
#Triangle F2
147162
n_id = m_analysis.find_template_opposite_node(d21)
148163
if n_id is not None:
149-
template[n_darts - 1, 6] = nodes_scores[n_id]
150-
template[n_darts - 1, deep+6] = nodes_adjacency[n_id]
164+
n = Node(m_analysis.mesh, n_id)
165+
template[n_darts - 1, 3] = n.get_score()
166+
template[n_darts - 1, deep+6] = d21.get_quality()
151167
n_id = m_analysis.find_template_opposite_node(d211)
152168
if n_id is not None:
153-
template[n_darts - 1, 7] = nodes_scores[n_id]
154-
template[n_darts - 1, deep+7] = nodes_adjacency[n_id]
169+
n = Node(m_analysis.mesh, n_id)
170+
template[n_darts - 1, 3] = n.get_score()
171+
template[n_darts - 1, deep+7] = d211.get_quality()
155172
# Triangle T3
156173
d12 = d1.get_beta(2)
157174
if d12 is not None:
158175
d121 = d12.get_beta(1)
159176
d1211 = d121.get_beta(1)
160177
n_id = m_analysis.find_template_opposite_node(d121)
161178
if n_id is not None:
162-
template[n_darts - 1, 8] = nodes_scores[n_id]
163-
template[n_darts - 1, deep+8] = nodes_adjacency[n_id]
179+
n = Node(m_analysis.mesh, n_id)
180+
template[n_darts - 1, 3] = n.get_score()
181+
template[n_darts - 1, deep+8] = d121.get_quality()
164182
n_id = m_analysis.find_template_opposite_node(d1211)
165183
if n_id is not None:
166-
template[n_darts - 1, 9] = nodes_scores[n_id]
167-
template[n_darts - 1, deep+9] = nodes_adjacency[n_id]
184+
n = Node(m_analysis.mesh, n_id)
185+
template[n_darts - 1, 3] = n.get_score()
186+
template[n_darts - 1, deep+9] = d1211.get_quality()
168187
# Triangle T4
169188
d112 = d11.get_beta(2)
170189
if d112 is not None:
171190
d1121 = d112.get_beta(1)
172191
d11211 = d1121.get_beta(1)
173192
n_id = m_analysis.find_template_opposite_node(d1121)
174193
if n_id is not None:
175-
template[n_darts - 1, 10] = nodes_scores[n_id]
176-
template[n_darts - 1, deep+10] = nodes_adjacency[n_id]
194+
n = Node(m_analysis.mesh, n_id)
195+
template[n_darts - 1, 3] = n.get_score()
196+
template[n_darts - 1, deep+10] = d1121.get_quality()
177197
n_id = m_analysis.find_template_opposite_node(d11211)
178198
if n_id is not None:
179-
template[n_darts - 1, 11] = nodes_scores[n_id]
180-
template[n_darts - 1, deep+11] = nodes_adjacency[n_id]
199+
n = Node(m_analysis.mesh, n_id)
200+
template[n_darts - 1, 3] = n.get_score()
201+
template[n_darts - 1, deep+11] = d11211.get_quality()
181202

182203
template = template[:n_darts, :]
183204
return template, dart_ids

environment/gymnasium_envs/trimesh_full_env/envs/trimesh.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1+
from copy import deepcopy
12
from enum import Enum
23
import gymnasium as gym
34
from gymnasium import spaces
45
import pygame
56
from pygame.locals import *
67
import numpy as np
78
import sys
9+
import imageio
810

911
from mesh_model.random_trimesh import random_mesh
1012
from mesh_model.mesh_struct.mesh_elements import Dart
1113
from mesh_model.mesh_analysis.trimesh_analysis import TriMeshGeoAnalysis, TriMeshTopoAnalysis
1214
from environment.gymnasium_envs.trimesh_full_env.envs.mesh_conv import get_x
13-
from environment.actions.triangular_actions import flip_edge, split_edge, collapse_edge
15+
from environment.actions.triangular_actions import flip_edge, split_edge, collapse_edge, check_mesh
16+
from view.mesh_plotter.mesh_plots import plot_mesh
1417
from view.window import window_data, graph
1518
from mesh_display import MeshDisplay
1619

@@ -27,7 +30,7 @@ class Actions(Enum):
2730
class TriMeshEnvFull(gym.Env):
2831
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 60}
2932

30-
def __init__(self, mesh=None, mesh_size=9, n_darts_selected=20, deep=6, with_degree_obs=False, action_restriction=False, render_mode=None):
33+
def __init__(self, mesh=None, mesh_size=9, max_episode_steps=20, n_darts_selected=20, deep=6, with_degree_obs=False, action_restriction=False, render_mode=None):
3134
self.mesh = mesh if mesh is not None else random_mesh(mesh_size)
3235
self.m_analysis = TriMeshTopoAnalysis(self.mesh)
3336
self.mesh_size = mesh_size
@@ -44,6 +47,9 @@ def __init__(self, mesh=None, mesh_size=9, n_darts_selected=20, deep=6, with_deg
4447
self.nb_invalid_actions = 0
4548
self.darts_selected = [] # darts id observed
4649
deep = self.deep*2 if self.degree_observation else deep
50+
self.episode_count = 0
51+
self.ep_len = 0
52+
self.max_steps = max_episode_steps
4753

4854
self.observation_space = spaces.Box(
4955
low=-15, # nodes min degree : 15
@@ -59,7 +65,6 @@ def __init__(self, mesh=None, mesh_size=9, n_darts_selected=20, deep=6, with_deg
5965

6066
assert render_mode is None or render_mode in self.metadata["render_modes"]
6167
self.render_mode = render_mode
62-
6368
"""
6469
If human-rendering is used, `self.window` will be a reference
6570
to the window that we draw to. `self.clock` will be a clock that is used
@@ -92,10 +97,17 @@ def reset(self, seed=None, options=None):
9297
self._nodes_scores, self._mesh_score, self._ideal_score, self._nodes_adjacency = self.m_analysis.global_score()
9398
self._ideal_rewards = (self._mesh_score - self._ideal_score) * 10
9499
self.nb_invalid_actions = 0
100+
self.ep_len = 0
95101
self.close()
96102
self.observation = self._get_obs()
97103
info = self._get_info(terminated=False,valid_act=(None,None,None), action=(None,None), mesh_reward=None)
98104

105+
"""
106+
if self._ideal_score !=0:
107+
self.render_mode = "human"
108+
else:
109+
self.render_mode = None
110+
"""
99111
if self.render_mode == "human":
100112
self._render_frame()
101113
self.recording = True
@@ -113,6 +125,8 @@ def _get_obs(self):
113125

114126
def _get_info(self, terminated, valid_act, action, mesh_reward):
115127
valid_action, valid_topo, valid_geo = valid_act
128+
if self._mesh_score - self._ideal_score <0:
129+
raise ValueError("score impossible")
116130
return {
117131
"distance": self._mesh_score - self._ideal_score,
118132
"mesh_reward" : mesh_reward,
@@ -139,13 +153,14 @@ def _action_to_dart_id(self, action: np.ndarray) -> int:
139153
return self.darts_selected[int(action[1])]
140154

141155
def step(self, action: np.ndarray):
156+
self.ep_len+=1
142157
dart_id = self._action_to_dart_id(action)
143158
d = Dart(self.mesh, dart_id)
144159
d1 = d.get_beta(1)
145160
n1 = d.get_node()
146161
n2 = d1.get_node()
147162
valid_action, valid_topo, valid_geo = False, False, False
148-
163+
before_mesh = deepcopy(self.mesh)
149164
if action[0] == Actions.FLIP.value:
150165
valid_action, valid_topo, valid_geo = flip_edge(self.m_analysis, n1, n2)
151166
elif action[0] == Actions.SPLIT.value:
@@ -160,6 +175,12 @@ def step(self, action: np.ndarray):
160175
next_nodes_score, self.next_mesh_score, _, next_nodes_adjacency = self.m_analysis.global_score()
161176
terminated = np.array_equal(self._ideal_score, self.next_mesh_score)
162177
mesh_reward = (self._mesh_score - self.next_mesh_score)*10
178+
if mesh_reward == 10:
179+
b_mesh_analysis = TriMeshTopoAnalysis(before_mesh)
180+
plot_mesh(before_mesh)
181+
plot_mesh(self.mesh)
182+
bool1 = check_mesh(b_mesh_analysis)
183+
bool2 = check_mesh(self.m_analysis)
163184
reward = mesh_reward
164185
self._nodes_scores, self._mesh_score, self._nodes_adjacency = next_nodes_score, self.next_mesh_score, next_nodes_adjacency
165186
self.observation = self._get_obs()
@@ -185,6 +206,12 @@ def step(self, action: np.ndarray):
185206

186207
if self.render_mode == "human":
187208
self._render_frame()
209+
#Saving episode rendering as gif
210+
if terminated or self.ep_len>= self.max_steps:
211+
if self.recording and self.frames:
212+
imageio.mimsave(f"training/episode_recording/episode_{self.episode_count}.gif", self.frames, fps=1)
213+
print("Image recorded")
214+
self.episode_count +=1
188215

189216
return self.observation, reward, terminated, truncated, info
190217

0 commit comments

Comments
 (0)