Skip to content

Commit 2c4feca

Browse files
committed
Add a register to count the number of times an observation is seen by the RL agent.
Ensure that the quadmesh and trimesh environments are functional with SB3 PPO and personal PPO.
1 parent 305bda5 commit 2c4feca

File tree

15 files changed

+408
-147
lines changed

15 files changed

+408
-147
lines changed

environment/environment_config.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
"env_name": "Quadmesh-v0",
33
"mesh_size": 16,
44
"max_episode_steps": 30,
5-
"n_darts_selected": 25,
6-
"deep": 6,
5+
"n_darts_selected": 10,
6+
"deep": 8,
77
"action_restriction": false,
88
"with_degree_observation": false
99
}

environment/gymnasium_envs/quadmesh_env/envs/quadmesh.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from mesh_model.mesh_analysis.quadmesh_analysis import global_score, isTruncated
1010
from environment.gymnasium_envs.quadmesh_env.envs.mesh_conv import get_x
1111
from environment.actions.quadrangular_actions import flip_edge, split_edge, collapse_edge, cleanup_edge
12+
from environment.observation_register import ObservationRegistry
1213

1314

1415
class Actions(Enum):
@@ -21,7 +22,7 @@ class Actions(Enum):
2122
class QuadMeshEnv(gym.Env):
2223
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 60}
2324

24-
def __init__(self, mesh=None, n_darts_selected=20, deep=6, with_degree_obs=True, action_restriction=False, render_mode=None):
25+
def __init__(self, mesh=None, max_episode_steps=30, n_darts_selected=20, deep=6, with_degree_obs=True, action_restriction=False, render_mode=None):
2526
if mesh is not None:
2627
self.config = {"mesh": mesh}
2728
self.mesh = copy.deepcopy(mesh)
@@ -33,21 +34,22 @@ def __init__(self, mesh=None, n_darts_selected=20, deep=6, with_degree_obs=True,
3334
self._nodes_scores, self._mesh_score, self._ideal_score, self._nodes_adjacency = global_score(self.mesh)
3435
self._ideal_rewards = (self._mesh_score - self._ideal_score)*10
3536
self.next_mesh_score = 0
36-
self.deep = deep
3737
self.n_darts_selected = n_darts_selected
3838
self.restricted = action_restriction
3939
self.degree_observation = with_degree_obs
4040
self.window_size = 512 # The size of the PyGame window
4141
self.g = None
4242
self.nb_invalid_actions = 0
43+
self.max_steps = max_episode_steps
4344
self.darts_selected = [] # darts id observed
44-
deep = self.deep*2 if self.degree_observation else deep
45+
self.deep = deep*2 if self.degree_observation else deep
4546
self.observation_space = spaces.Box(
46-
low=-15, # nodes min degree : -15
47-
high=15, # nodes max degree : 15
48-
shape=(self.n_darts_selected, self.deep * 2 if self.degree_observation else self.deep),
47+
low=-6, # nodes min degree : -15
48+
high=2, # nodes max degree : 15
49+
shape=(self.n_darts_selected, deep),
4950
dtype=np.int64
5051
)
52+
self.observation_count = ObservationRegistry(self.n_darts_selected, self.deep, -6, 2)
5153

5254
self.observation = None
5355

@@ -102,6 +104,8 @@ def _get_info(self, terminated, valid_act, action, mesh_reward):
102104
"invalid_collapse": 1.0 if action[0]==Actions.COLLAPSE.value and not valid_action else 0.0,
103105
"invalid_cleanup": 1.0 if action[0]==Actions.CLEANUP.value and not valid_action else 0.0,
104106
"mesh" : self.mesh,
107+
"darts_selected" : self.darts_selected,
108+
"observation_count" : self.observation_count,
105109
}
106110

107111
def _action_to_dart_id(self, action: np.ndarray) -> int:
@@ -131,6 +135,7 @@ def step(self, action: np.ndarray):
131135
else:
132136
raise ValueError("Action not defined")
133137

138+
self.observation_count.register_observation(self.observation)
134139
if valid_action:
135140
# An episode is done if the actual score is the same as the ideal
136141
next_nodes_score, self.next_mesh_score, _, next_nodes_adjacency = global_score(self.mesh)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
2+
class ObservationRegistry:
3+
def __init__(self, n_darts_selected, deep, lowest_value, highest_value):
4+
self.n_darts = n_darts_selected
5+
self.deep = deep
6+
self.low = lowest_value
7+
self.high = highest_value
8+
self.counts = {}
9+
10+
def encode(self, observation):
11+
"""
12+
Converts an observation into a tuple.
13+
:param observation:
14+
:return: the tuple ID of the observation
15+
"""
16+
ID = tuple(tuple(dart_surrounding) for dart_surrounding in observation)
17+
return ID
18+
19+
def decode(self, ID):
20+
"""
21+
Reconstructs an observation from its ID.
22+
:param ID: a tuple ID
23+
:return: the observation matrix
24+
"""
25+
return [list(dart_surrounding) for dart_surrounding in ID]
26+
27+
def register_observation(self, observation) -> None:
28+
"""
29+
Adds an observation to the register and increments its counter.
30+
:param observation:
31+
"""
32+
ID = self.encode(observation)
33+
if self.counts.get(ID) is None:
34+
self.counts[ID] = 1
35+
else:
36+
self.counts[ID] += 1
37+
38+
def get_count(self, observation):
39+
"""
40+
Returns the number of times an observation has been seen by the agent.
41+
:param observation:
42+
:return: counts of observations
43+
"""
44+
ID = self.encode(observation)
45+
return self.counts.get(ID, 0)
46+

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22

33
from user_game import user_game
4-
from training.train import train
4+
from training.train_quadmesh import train
55
from training.exploit import exploit
66
#from mesh_model.reader import read_gmsh
77
#from mesh_display import MeshDisplay

mesh_model/mesh_analysis/global_mesh_analysis.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,14 @@ def score_calculation(n: Node) -> (int, int):
3939
:return: the irregularity of the node
4040
"""
4141
adjacency = degree(n)
42-
m = n.mesh
4342
# Check if the mesh is triangular or quad
44-
d = Dart(m, m.dart_info[0,0])
45-
triangular = d.get_beta(1).get_beta(1).get_beta(1) == d
43+
d = n.get_dart()
44+
if d.id <0:
45+
raise ValueError("No existing dart")
46+
d1 = d.get_beta(1)
47+
d11 = d1.get_beta(1)
48+
d111 = d11.get_beta(1)
49+
triangular = (d111.id == d.id)
4650
if on_boundary(n):
4751
angle = get_boundary_angle(n)
4852
if triangular:
@@ -85,6 +89,7 @@ def get_angle(d1: Dart, d2: Dart, n: Node) -> float:
8589
cos_theta = np.clip(cos_theta, -1, 1)
8690
angle = np.arccos(cos_theta)
8791
if np.isnan(angle):
92+
plot_mesh(n.mesh)
8893
raise ValueError("Angle error")
8994
return degrees(angle)
9095

mesh_model/mesh_analysis/quadmesh_analysis.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,14 @@ def isFlipOk(d: Dart) -> (bool, bool):
104104
return topo, geo
105105
else:
106106
d2, d1, d11, d111, d21, d211, d2111, n1, n2, n3, n4, n5, n6 = mesh.active_quadrangles(d)
107-
107+
# if degree are
108108
if not test_degree(n5) or not test_degree(n3):
109109
topo = False
110110
return topo, geo
111111

112112
if d211.get_node() == d111.get_node() or d11.get_node() == d2111.get_node():
113113
topo = False
114114
return topo, geo
115-
116115
topo = isValidQuad(n5, n6, n2, n3) and isValidQuad(n1, n5, n3, n4)
117116

118117
"""
@@ -174,6 +173,8 @@ def isCollapseOk(d: Dart) -> (bool, bool):
174173
f3 = (d11.get_beta(2)).get_face()
175174
f4 = (d111.get_beta(2)).get_face()
176175
adjacent_faces_lst=[f1.id, f2.id, f3.id, f4.id]
176+
177+
# Check that there are no adjacent faces in common
177178
if len(adjacent_faces_lst) != len(set(adjacent_faces_lst)):
178179
topo = False
179180
return topo, geo
@@ -251,6 +252,13 @@ def isValidQuad(A: Node, B: Node, C: Node, D: Node):
251252
u3 = np.array([D.x() - C.x(), D.y() - C.y()]) # vect(CD)
252253
u4 = np.array([A.x() - D.x(), A.y() - D.y()]) # vect(DA)
253254

255+
# Checking for near-zero vectors (close to (0,0))
256+
if (np.linalg.norm(u1) < 1e-5 or
257+
np.linalg.norm(u2) < 1e-5 or
258+
np.linalg.norm(u3) < 1e-5 or
259+
np.linalg.norm(u4) < 1e-5):
260+
return False # Quad invalid because one side is almost zero
261+
254262
cp_A = cross_product(-1*u4, u1)
255263
cp_B = cross_product(-1*u1, u2)
256264
cp_C = cross_product(-1*u2, u3)
@@ -263,13 +271,12 @@ def isValidQuad(A: Node, B: Node, C: Node, D: Node):
263271
return False
264272
265273
"""
266-
if signe(cp_A)+signe(cp_B)+signe(cp_C)+signe(cp_D)<2:
274+
if 0< signe(cp_A)+signe(cp_B)+signe(cp_C)+signe(cp_D) <2 :
267275
return True
268276
else:
269277
return False
270278

271279

272-
273280
def orientation(p, q, r):
274281
""" Calcule l'orientation de trois points.
275282
Retourne :

mesh_model/mesh_struct/mesh_elements.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
2+
import numpy as np
33

44
class Dart:
55
_mesh_type: type = None
@@ -19,6 +19,8 @@ def __init__(self, m: _mesh_type, dart_id: int):
1919
:param id: a index, that corresponds to the location of the dart data in the mesh_struct dart container
2020
"""
2121
self.mesh = m
22+
if not isinstance(dart_id, (int, np.integer)):
23+
raise ValueError(f"The id must be an integer, {dart_id} is type {type(dart_id)}.")
2224
self.id = dart_id
2325

2426
def __eq__(self, a_dart: Dart) -> bool:
@@ -40,7 +42,6 @@ def get_beta(self, i: int) -> Dart:
4042
"""
4143
if i < 1 or i > 2:
4244
raise ValueError("Wrong alpha dimension")
43-
4445
if self.mesh.dart_info[self.id, i] == -1:
4546
return None
4647

@@ -135,7 +136,7 @@ def get_dart(self) -> Dart:
135136
Get the dart value associated with this node
136137
:return: a dart
137138
"""
138-
return Dart(self.mesh, self.mesh.nodes[self.id, 2])
139+
return Dart(self.mesh, int(self.mesh.nodes[self.id, 2]))
139140

140141
def x(self) -> float:
141142
"""

mesh_model/random_quadmesh.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def random_mesh() -> Mesh:
1515
:param num_nodes_max: number of nodes of the final mesh
1616
:return: a random mesh
1717
"""
18-
filename = os.path.join('../mesh_files/', 't1_quad.msh')
19-
mesh = read_gmsh(filename)
18+
#filename = os.path.join('../mesh_files/', 't1_quad.msh')
19+
mesh = read_gmsh("/home/ropercha/PycharmProjects/tune/mesh_files/t1_quad.msh")
2020
mesh_shuffle(mesh, 10)
2121
return mesh
2222

model_RL/PPO_model_new.py

Lines changed: 0 additions & 123 deletions
This file was deleted.

0 commit comments

Comments
 (0)