Skip to content

Commit ed50fd8

Browse files
committed
Add YAML configuration files for parameters
Enable rendering in QuadMesh environment Add posibility to use SB3 vectorized environments Refactor ObservationRegistry to use a pandas DataFrame
1 parent e118d27 commit ed50fd8

File tree

11 files changed

+581
-264
lines changed

11 files changed

+581
-264
lines changed

environment/gymnasium_envs/quadmesh_env/envs/quadmesh.py

Lines changed: 191 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
from enum import Enum
2-
import gymnasium as gym
3-
from gymnasium import spaces
4-
import numpy as np
1+
52
import copy
3+
import pygame
4+
import imageio
5+
import sys
6+
7+
import numpy as np
8+
import gymnasium as gym
9+
10+
from enum import Enum
11+
from typing import Optional
12+
from pygame.locals import *
613

714
from mesh_model.random_quadmesh import random_mesh
815
from mesh_model.mesh_struct.mesh_elements import Dart
@@ -11,6 +18,8 @@
1118
from environment.gymnasium_envs.quadmesh_env.envs.mesh_conv import get_x
1219
from environment.actions.quadrangular_actions import flip_edge_cntcw, flip_edge_cw, split_edge, collapse_edge, cleanup_edge
1320
from environment.observation_register import ObservationRegistry
21+
from view.window import window_data, graph
22+
from mesh_display import MeshDisplay
1423

1524

1625
class Actions(Enum):
@@ -22,41 +31,100 @@ class Actions(Enum):
2231

2332

2433
class QuadMeshEnv(gym.Env):
25-
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 60}
34+
"""
35+
QuadMesh environment is structured according to gymnasium and is used to topologically optimize quadrangular meshes topologically.
36+
The generated observations consist of a local topological view of the mesh. They are structured in the form of matrices :
37+
* The columns correspond to the surrounding area of a mesh dart.
38+
* Only the darts with the most irregularities in the surrounding area are retained.
39+
40+
Based on these observations, the agent will choose from 4 different actions:
41+
* flip clockwise, flip an edge clockwise
42+
* flip counterclockwise, flip an edge counterclockwise
43+
* split, add a face
44+
* collapse, deleting a face
45+
46+
These actions will generate rewards proportional to the improvement or deterioration of the mesh. If the chosen action is invalid, a penalty is returned.
47+
"""
48+
49+
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}
2650

27-
def __init__(self, mesh=None, max_episode_steps=30, n_darts_selected=20, deep=6, with_degree_obs=True, action_restriction=False, render_mode=None):
51+
def __init__(
52+
self,
53+
mesh=None,
54+
max_episode_steps: int =50,
55+
n_darts_selected: int =20,
56+
deep: int =6,
57+
render_mode: Optional[str] = None,
58+
with_degree_obs: bool =True,
59+
action_restriction: bool =False,
60+
obs_count: bool=False,
61+
) -> None:
62+
63+
assert render_mode is None or render_mode in self.metadata["render_modes"]
64+
self.render_mode = render_mode
65+
66+
#If a mesh has been entered, it is used, otherwise a random mesh is generated.
2867
if mesh is not None:
2968
self.config = {"mesh": mesh}
3069
self.mesh = copy.deepcopy(mesh)
3170
else :
3271
self.config = {"mesh": None}
3372
self.mesh = random_mesh()
34-
self.mesh_size = len(self.mesh.nodes)
35-
self.nb_darts = len(self.mesh.dart_info)
73+
74+
#self.mesh_size = len(self.mesh.nodes)
75+
#self.nb_darts = len(self.mesh.dart_info)
3676
self._nodes_scores, self._mesh_score, self._ideal_score, self._nodes_adjacency = global_score(self.mesh)
37-
self._ideal_rewards = (self._mesh_score - self._ideal_score)*10
77+
self._ideal_rewards = (self._mesh_score - self._ideal_score)*10 #arbitrary factor of 10 for rewards
3878
self.next_mesh_score = 0
3979
self.n_darts_selected = n_darts_selected
4080
self.restricted = action_restriction
4181
self.degree_observation = with_degree_obs
42-
self.window_size = 512 # The size of the PyGame window
43-
self.g = None
4482
self.nb_invalid_actions = 0
4583
self.max_steps = max_episode_steps
84+
self.episode_count = 0
85+
self.ep_len = 0
4686
self.darts_selected = [] # darts id observed
4787
self.deep = deep*2 if self.degree_observation else deep
48-
self.observation_space = spaces.Box(
88+
self.actions_info = {
89+
"n_flip_cntcw": 0,
90+
"n_flip_ccw": 0,
91+
"n_split": 0,
92+
"n_collapse": 0,
93+
"n_cleanup": 0,
94+
}
95+
96+
# Definition of an observation register if required
97+
if obs_count:
98+
self.observation_count = True
99+
self.observation_registry = ObservationRegistry(self.n_darts_selected, self.deep, -6, 2)
100+
else:
101+
self.observation_count = False
102+
103+
# Render
104+
if self.render_mode == "human":
105+
self.mesh_disp = MeshDisplay(self.mesh)
106+
self.graph = graph.Graph(self.mesh_disp.get_nodes_coordinates(), self.mesh_disp.get_edges(),
107+
self.mesh_disp.get_scores())
108+
self.win_data = window_data()
109+
self.window_size = 512 # The size of the PyGame window
110+
self.window = None
111+
self.clock = None
112+
113+
self.recording = False
114+
self.frames = []
115+
116+
# Observation and action spaces
117+
self.observation_space = gym.spaces.Box(
49118
low=-6, # nodes min degree : -6
50119
high=2, # nodes max degree : 2
51120
shape=(self.n_darts_selected, deep),
52121
dtype=np.int64
53122
)
54-
self.observation_count = ObservationRegistry(self.n_darts_selected, self.deep, -6, 2)
55-
56123
self.observation = None
57124

58-
# We have 4 actions, flip clockwise, flip counterclockwise, split, collapse, cleanup
59-
self.action_space = spaces.MultiDiscrete([4, self.n_darts_selected])
125+
# We have 4 actions, flip clockwise, flip counterclockwise, split, collapse
126+
self.action_space = gym.spaces.MultiDiscrete([4, self.n_darts_selected])
127+
60128

61129

62130
def reset(self, seed=None, options=None):
@@ -68,16 +136,28 @@ def reset(self, seed=None, options=None):
68136
self.mesh = copy.deepcopy(self.config["mesh"])
69137
else:
70138
self.mesh = random_mesh()
71-
self.nb_darts = len(self.mesh.dart_info)
139+
#self.nb_darts = len(self.mesh.dart_info)
72140
self._nodes_scores, self._mesh_score, self._ideal_score, self._nodes_adjacency = global_score(self.mesh)
73141
self._ideal_rewards = (self._mesh_score - self._ideal_score) * 10
74142
self.nb_invalid_actions = 0
75143
self.close()
76144
self.observation = self._get_obs()
145+
self.ep_len = 0
77146
info = self._get_info(terminated=False,valid_act=(None,None,None), action=(None,None), mesh_reward=None)
147+
self.actions_info = {
148+
"n_flip_cw": 0,
149+
"n_flip_cntcw": 0,
150+
"n_split": 0,
151+
"n_collapse": 0,
152+
"n_cleanup": 0,
153+
}
78154

79-
if self.render_mode == "human":
155+
if self.render_mode=="human":
80156
self._render_frame()
157+
self.recording = True
158+
else:
159+
self.recording = False
160+
self.frames = []
81161

82162
return self.observation, info
83163

@@ -108,7 +188,7 @@ def _get_info(self, terminated, valid_act, action, mesh_reward):
108188
"invalid_cleanup": 1.0 if action[0]==Actions.CLEANUP.value and not valid_action else 0.0,
109189
"mesh" : self.mesh,
110190
"darts_selected" : self.darts_selected,
111-
"observation_count" : self.observation_count,
191+
"observation_registry" : self.observation_registry if self.observation_count else None,
112192
}
113193

114194
def _action_to_dart_id(self, action: np.ndarray) -> int:
@@ -120,27 +200,34 @@ def _action_to_dart_id(self, action: np.ndarray) -> int:
120200
return self.darts_selected[int(action[1])]
121201

122202
def step(self, action: np.ndarray):
203+
self.ep_len+=1
123204
dart_id = self._action_to_dart_id(action)
124205
d = Dart(self.mesh, dart_id)
125206
d1 = d.get_beta(1)
126207
n1 = d.get_node()
127208
n2 = d1.get_node()
128209
valid_action, valid_topo, valid_geo = False, False, False
129-
130210
if action[0] == Actions.FLIP_CW.value:
211+
self.actions_info["n_flip_cw"] += 1
131212
valid_action, valid_topo, valid_geo = flip_edge_cw(self.mesh, n1, n2)
132213
elif action[0] == Actions.FLIP_CNTCW.value:
214+
self.actions_info["n_flip_cntcw"] += 1
133215
valid_action, valid_topo, valid_geo = flip_edge_cntcw(self.mesh, n1, n2)
134216
elif action[0] == Actions.SPLIT.value:
217+
self.actions_info["n_split"] += 1
135218
valid_action, valid_topo, valid_geo = split_edge(self.mesh, n1, n2)
136219
elif action[0] == Actions.COLLAPSE.value:
220+
self.actions_info["n_collapse"] += 1
137221
valid_action, valid_topo, valid_geo = collapse_edge(self.mesh, n1, n2)
138222
elif action[0] == Actions.CLEANUP.value:
223+
self.actions_info["n_cleanup"] += 1
139224
valid_action, valid_topo, valid_geo = cleanup_edge(self.mesh, n1, n2)
140225
else:
141226
raise ValueError("Action not defined")
142227

143-
self.observation_count.register_observation(self.observation)
228+
if self.observation_count:
229+
self.observation_registry.register_observation(self.observation)
230+
144231
if valid_action:
145232
# An episode is done if the actual score is the same as the ideal
146233
next_nodes_score, self.next_mesh_score, _, next_nodes_adjacency = global_score(self.mesh)
@@ -171,6 +258,89 @@ def step(self, action: np.ndarray):
171258
else:
172259
truncated = False
173260
valid_act = valid_action, valid_topo, valid_geo
261+
174262
info = self._get_info(terminated, valid_act, action, mesh_reward)
175263

264+
if self.render_mode == "human":
265+
self._render_frame()
266+
if terminated or self.ep_len>= self.max_steps:
267+
if self.recording and self.frames:
268+
imageio.mimsave(f"episode_{self.episode_count}.gif", self.frames, fps=1)
269+
print("Image recorded")
270+
self.episode_count +=1
271+
176272
return self.observation, reward, terminated, truncated, info
273+
274+
275+
def _render_frame(self):
276+
if self.render_mode == "human" and self.window is None:
277+
pygame.init()
278+
pygame.display.init()
279+
self.window = pygame.display.set_mode(self.win_data.size, self.win_data.options)
280+
pygame.display.set_caption('QuadMesh')
281+
self.window.fill((255, 255, 255))
282+
self.font = pygame.font.SysFont(None, self.win_data.font_size)
283+
self.clock = pygame.time.Clock()
284+
self.clock.tick(60)
285+
self.win_data.scene_xmin, self.win_data.scene_ymin, self.win_data.scene_xmax, self.win_data.scene_ymax = self.graph.bounding_box()
286+
self.win_data.scene_center = pygame.math.Vector2((self.win_data.scene_xmax + self.win_data.scene_xmin) / 2.0,
287+
(self.win_data.scene_ymax + self.win_data.scene_ymin) / 2.0)
288+
289+
pygame.event.pump()
290+
self.window.fill((255, 255, 255)) # white
291+
for event in pygame.event.get():
292+
if event.type == QUIT:
293+
pygame.quit()
294+
sys.exit()
295+
296+
if event.type == VIDEORESIZE or event.type == VIDEOEXPOSE: # handles window minimising/maximising
297+
x, y = self.window.get_size()
298+
text_margin = 200
299+
self.win_data.center.x = (x - text_margin) / 2
300+
self.win_data.center.y = y / 2
301+
ratio = float(x - text_margin) / float(self.win_data.scene_xmax - self.win_data.scene_xmin)
302+
ratio_y = float(y) / float(self.win_data.scene_ymax - self.win_data.scene_ymin)
303+
if ratio_y < ratio:
304+
ratio = ratio_y
305+
306+
self.win_data.node_size = max(ratio / 100, 10)
307+
self.win_data.stretch = 0.75 * ratio
308+
309+
self.window.fill((255, 255, 255))
310+
pygame.display.flip()
311+
312+
self.graph.clear()
313+
self.mesh_disp = MeshDisplay(self.mesh)
314+
self.graph.update(self.mesh_disp.get_nodes_coordinates(), self.mesh_disp.get_edges(),
315+
self.mesh_disp.get_scores())
316+
317+
#Draw mesh
318+
for e in self.graph.edges:
319+
e.draw(self.window, self.win_data)
320+
for n in self.graph.vertices:
321+
n.draw(self.window, self.font, self.win_data)
322+
323+
#Print action type
324+
if hasattr(self, 'actions_info'):
325+
x = self.window.get_width() - 150
326+
y_start = 100
327+
line_spacing = 25
328+
329+
for i, (action_name, count) in enumerate(self.actions_info.items()):
330+
text = f"{action_name}: {count}"
331+
text_surface = self.font.render(text, True, (0, 0, 0))
332+
self.window.blit(text_surface, (x, y_start + i * line_spacing))
333+
334+
self.clock.tick(60)
335+
pygame.time.delay(1200)
336+
pygame.display.flip()
337+
if self.recording:
338+
pixels = pygame.surfarray.array3d(pygame.display.get_surface())
339+
frame = pixels.transpose([1,0,2])
340+
self.frames.append(frame)
341+
342+
def close(self):
343+
if self.render_mode=="human" and self.window is not None:
344+
pygame.display.quit()
345+
pygame.quit()
346+
self.window = None

environment/observation_register.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import pandas as pd
2+
13

24
class ObservationRegistry:
35
def __init__(self, n_darts_selected, deep, lowest_value, highest_value):
46
self.n_darts = n_darts_selected
57
self.deep = deep
68
self.low = lowest_value
79
self.high = highest_value
8-
self.counts = {}
10+
self.df = pd.DataFrame(columns=["counts"])
11+
self.df.index.name = "observations"
912

1013
def encode(self, observation):
1114
"""
@@ -30,10 +33,14 @@ def register_observation(self, observation) -> None:
3033
:param observation:
3134
"""
3235
ID = self.encode(observation)
33-
if self.counts.get(ID) is None:
34-
self.counts[ID] = 1
36+
37+
if self.df.empty:
38+
self.df = pd.DataFrame({"counts": [1]}, index=[ID])
39+
elif ID in self.df.index:
40+
self.df.at[ID, "counts"] += 1
3541
else:
36-
self.counts[ID] += 1
42+
new_row = pd.DataFrame({"counts": [1]}, index=[ID])
43+
self.df = pd.concat([self.df, new_row])
3744

3845
def get_count(self, observation):
3946
"""
@@ -42,5 +49,20 @@ def get_count(self, observation):
4249
:return: counts of observations
4350
"""
4451
ID = self.encode(observation)
45-
return self.counts.get(ID, 0)
52+
return int(self.df.loc[ID])
4653

54+
def save(self, path):
55+
if path.endswith(".csv"):
56+
self.df.to_csv(path)
57+
elif path.endswith(".parquet"):
58+
self.df.to_parquet(path)
59+
else:
60+
print("Unsupported file type")
61+
62+
def load_counts(self, path):
63+
if path.endswith(".csv"):
64+
self.df = pd.read_csv(path)
65+
elif path.endswith(".parquet"):
66+
self.df = pd.read_parquet(path)
67+
else:
68+
print("Unsupported file type")

model_RL/PPO_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Old version of PPO for triangular environement
3+
"""
4+
15
from model_RL.utilities.actor_critic_networks import NaNExceptionActor, NaNExceptionCritic, Actor, Critic
26
from mesh_model.mesh_analysis.global_mesh_analysis import global_score
37
import copy

model_RL/PPO_model_pers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,4 +275,4 @@ def learn(self, writer):
275275
print("NaN Exception on Critic Network")
276276
return None, None, None, None
277277

278-
return self.actor, rewards, wins, len_ep, info["observation_count"]
278+
return self.actor, rewards, wins, len_ep, info["observation_registry"]

0 commit comments

Comments
 (0)