1- from enum import Enum
2- import gymnasium as gym
3- from gymnasium import spaces
4- import numpy as np
1+
52import copy
3+ import pygame
4+ import imageio
5+ import sys
6+
7+ import numpy as np
8+ import gymnasium as gym
9+
10+ from enum import Enum
11+ from typing import Optional
12+ from pygame .locals import *
613
714from mesh_model .random_quadmesh import random_mesh
815from mesh_model .mesh_struct .mesh_elements import Dart
1118from environment .gymnasium_envs .quadmesh_env .envs .mesh_conv import get_x
1219from environment .actions .quadrangular_actions import flip_edge_cntcw , flip_edge_cw , split_edge , collapse_edge , cleanup_edge
1320from environment .observation_register import ObservationRegistry
21+ from view .window import window_data , graph
22+ from mesh_display import MeshDisplay
1423
1524
1625class Actions (Enum ):
@@ -22,41 +31,100 @@ class Actions(Enum):
2231
2332
2433class QuadMeshEnv (gym .Env ):
25- metadata = {"render_modes" : ["human" , "rgb_array" ], "render_fps" : 60 }
34+ """
35+ QuadMesh environment is structured according to gymnasium and is used to topologically optimize quadrangular meshes topologically.
36+ The generated observations consist of a local topological view of the mesh. They are structured in the form of matrices :
37+ * The columns correspond to the surrounding area of a mesh dart.
38+ * Only the darts with the most irregularities in the surrounding area are retained.
39+
40+ Based on these observations, the agent will choose from 4 different actions:
41+ * flip clockwise, flip an edge clockwise
42+ * flip counterclockwise, flip an edge counterclockwise
43+ * split, add a face
44+ * collapse, deleting a face
45+
46+ These actions will generate rewards proportional to the improvement or deterioration of the mesh. If the chosen action is invalid, a penalty is returned.
47+ """
48+
49+ metadata = {"render_modes" : ["human" , "rgb_array" ], "render_fps" : 30 }
2650
27- def __init__ (self , mesh = None , max_episode_steps = 30 , n_darts_selected = 20 , deep = 6 , with_degree_obs = True , action_restriction = False , render_mode = None ):
51+ def __init__ (
52+ self ,
53+ mesh = None ,
54+ max_episode_steps : int = 50 ,
55+ n_darts_selected : int = 20 ,
56+ deep : int = 6 ,
57+ render_mode : Optional [str ] = None ,
58+ with_degree_obs : bool = True ,
59+ action_restriction : bool = False ,
60+ obs_count : bool = False ,
61+ ) -> None :
62+
63+ assert render_mode is None or render_mode in self .metadata ["render_modes" ]
64+ self .render_mode = render_mode
65+
66+ #If a mesh has been entered, it is used, otherwise a random mesh is generated.
2867 if mesh is not None :
2968 self .config = {"mesh" : mesh }
3069 self .mesh = copy .deepcopy (mesh )
3170 else :
3271 self .config = {"mesh" : None }
3372 self .mesh = random_mesh ()
34- self .mesh_size = len (self .mesh .nodes )
35- self .nb_darts = len (self .mesh .dart_info )
73+
74+ #self.mesh_size = len(self.mesh.nodes)
75+ #self.nb_darts = len(self.mesh.dart_info)
3676 self ._nodes_scores , self ._mesh_score , self ._ideal_score , self ._nodes_adjacency = global_score (self .mesh )
37- self ._ideal_rewards = (self ._mesh_score - self ._ideal_score )* 10
77+ self ._ideal_rewards = (self ._mesh_score - self ._ideal_score )* 10 #arbitrary factor of 10 for rewards
3878 self .next_mesh_score = 0
3979 self .n_darts_selected = n_darts_selected
4080 self .restricted = action_restriction
4181 self .degree_observation = with_degree_obs
42- self .window_size = 512 # The size of the PyGame window
43- self .g = None
4482 self .nb_invalid_actions = 0
4583 self .max_steps = max_episode_steps
84+ self .episode_count = 0
85+ self .ep_len = 0
4686 self .darts_selected = [] # darts id observed
4787 self .deep = deep * 2 if self .degree_observation else deep
48- self .observation_space = spaces .Box (
88+ self .actions_info = {
89+ "n_flip_cntcw" : 0 ,
90+ "n_flip_ccw" : 0 ,
91+ "n_split" : 0 ,
92+ "n_collapse" : 0 ,
93+ "n_cleanup" : 0 ,
94+ }
95+
96+ # Definition of an observation register if required
97+ if obs_count :
98+ self .observation_count = True
99+ self .observation_registry = ObservationRegistry (self .n_darts_selected , self .deep , - 6 , 2 )
100+ else :
101+ self .observation_count = False
102+
103+ # Render
104+ if self .render_mode == "human" :
105+ self .mesh_disp = MeshDisplay (self .mesh )
106+ self .graph = graph .Graph (self .mesh_disp .get_nodes_coordinates (), self .mesh_disp .get_edges (),
107+ self .mesh_disp .get_scores ())
108+ self .win_data = window_data ()
109+ self .window_size = 512 # The size of the PyGame window
110+ self .window = None
111+ self .clock = None
112+
113+ self .recording = False
114+ self .frames = []
115+
116+ # Observation and action spaces
117+ self .observation_space = gym .spaces .Box (
49118 low = - 6 , # nodes min degree : -6
50119 high = 2 , # nodes max degree : 2
51120 shape = (self .n_darts_selected , deep ),
52121 dtype = np .int64
53122 )
54- self .observation_count = ObservationRegistry (self .n_darts_selected , self .deep , - 6 , 2 )
55-
56123 self .observation = None
57124
58- # We have 4 actions, flip clockwise, flip counterclockwise, split, collapse, cleanup
59- self .action_space = spaces .MultiDiscrete ([4 , self .n_darts_selected ])
125+ # We have 4 actions, flip clockwise, flip counterclockwise, split, collapse
126+ self .action_space = gym .spaces .MultiDiscrete ([4 , self .n_darts_selected ])
127+
60128
61129
62130 def reset (self , seed = None , options = None ):
@@ -68,16 +136,28 @@ def reset(self, seed=None, options=None):
68136 self .mesh = copy .deepcopy (self .config ["mesh" ])
69137 else :
70138 self .mesh = random_mesh ()
71- self .nb_darts = len (self .mesh .dart_info )
139+ # self.nb_darts = len(self.mesh.dart_info)
72140 self ._nodes_scores , self ._mesh_score , self ._ideal_score , self ._nodes_adjacency = global_score (self .mesh )
73141 self ._ideal_rewards = (self ._mesh_score - self ._ideal_score ) * 10
74142 self .nb_invalid_actions = 0
75143 self .close ()
76144 self .observation = self ._get_obs ()
145+ self .ep_len = 0
77146 info = self ._get_info (terminated = False ,valid_act = (None ,None ,None ), action = (None ,None ), mesh_reward = None )
147+ self .actions_info = {
148+ "n_flip_cw" : 0 ,
149+ "n_flip_cntcw" : 0 ,
150+ "n_split" : 0 ,
151+ "n_collapse" : 0 ,
152+ "n_cleanup" : 0 ,
153+ }
78154
79- if self .render_mode == "human" :
155+ if self .render_mode == "human" :
80156 self ._render_frame ()
157+ self .recording = True
158+ else :
159+ self .recording = False
160+ self .frames = []
81161
82162 return self .observation , info
83163
@@ -108,7 +188,7 @@ def _get_info(self, terminated, valid_act, action, mesh_reward):
108188 "invalid_cleanup" : 1.0 if action [0 ]== Actions .CLEANUP .value and not valid_action else 0.0 ,
109189 "mesh" : self .mesh ,
110190 "darts_selected" : self .darts_selected ,
111- "observation_count " : self .observation_count ,
191+ "observation_registry " : self .observation_registry if self . observation_count else None ,
112192 }
113193
114194 def _action_to_dart_id (self , action : np .ndarray ) -> int :
@@ -120,27 +200,34 @@ def _action_to_dart_id(self, action: np.ndarray) -> int:
120200 return self .darts_selected [int (action [1 ])]
121201
122202 def step (self , action : np .ndarray ):
203+ self .ep_len += 1
123204 dart_id = self ._action_to_dart_id (action )
124205 d = Dart (self .mesh , dart_id )
125206 d1 = d .get_beta (1 )
126207 n1 = d .get_node ()
127208 n2 = d1 .get_node ()
128209 valid_action , valid_topo , valid_geo = False , False , False
129-
130210 if action [0 ] == Actions .FLIP_CW .value :
211+ self .actions_info ["n_flip_cw" ] += 1
131212 valid_action , valid_topo , valid_geo = flip_edge_cw (self .mesh , n1 , n2 )
132213 elif action [0 ] == Actions .FLIP_CNTCW .value :
214+ self .actions_info ["n_flip_cntcw" ] += 1
133215 valid_action , valid_topo , valid_geo = flip_edge_cntcw (self .mesh , n1 , n2 )
134216 elif action [0 ] == Actions .SPLIT .value :
217+ self .actions_info ["n_split" ] += 1
135218 valid_action , valid_topo , valid_geo = split_edge (self .mesh , n1 , n2 )
136219 elif action [0 ] == Actions .COLLAPSE .value :
220+ self .actions_info ["n_collapse" ] += 1
137221 valid_action , valid_topo , valid_geo = collapse_edge (self .mesh , n1 , n2 )
138222 elif action [0 ] == Actions .CLEANUP .value :
223+ self .actions_info ["n_cleanup" ] += 1
139224 valid_action , valid_topo , valid_geo = cleanup_edge (self .mesh , n1 , n2 )
140225 else :
141226 raise ValueError ("Action not defined" )
142227
143- self .observation_count .register_observation (self .observation )
228+ if self .observation_count :
229+ self .observation_registry .register_observation (self .observation )
230+
144231 if valid_action :
145232 # An episode is done if the actual score is the same as the ideal
146233 next_nodes_score , self .next_mesh_score , _ , next_nodes_adjacency = global_score (self .mesh )
@@ -171,6 +258,89 @@ def step(self, action: np.ndarray):
171258 else :
172259 truncated = False
173260 valid_act = valid_action , valid_topo , valid_geo
261+
174262 info = self ._get_info (terminated , valid_act , action , mesh_reward )
175263
264+ if self .render_mode == "human" :
265+ self ._render_frame ()
266+ if terminated or self .ep_len >= self .max_steps :
267+ if self .recording and self .frames :
268+ imageio .mimsave (f"episode_{ self .episode_count } .gif" , self .frames , fps = 1 )
269+ print ("Image recorded" )
270+ self .episode_count += 1
271+
176272 return self .observation , reward , terminated , truncated , info
273+
274+
275+ def _render_frame (self ):
276+ if self .render_mode == "human" and self .window is None :
277+ pygame .init ()
278+ pygame .display .init ()
279+ self .window = pygame .display .set_mode (self .win_data .size , self .win_data .options )
280+ pygame .display .set_caption ('QuadMesh' )
281+ self .window .fill ((255 , 255 , 255 ))
282+ self .font = pygame .font .SysFont (None , self .win_data .font_size )
283+ self .clock = pygame .time .Clock ()
284+ self .clock .tick (60 )
285+ self .win_data .scene_xmin , self .win_data .scene_ymin , self .win_data .scene_xmax , self .win_data .scene_ymax = self .graph .bounding_box ()
286+ self .win_data .scene_center = pygame .math .Vector2 ((self .win_data .scene_xmax + self .win_data .scene_xmin ) / 2.0 ,
287+ (self .win_data .scene_ymax + self .win_data .scene_ymin ) / 2.0 )
288+
289+ pygame .event .pump ()
290+ self .window .fill ((255 , 255 , 255 )) # white
291+ for event in pygame .event .get ():
292+ if event .type == QUIT :
293+ pygame .quit ()
294+ sys .exit ()
295+
296+ if event .type == VIDEORESIZE or event .type == VIDEOEXPOSE : # handles window minimising/maximising
297+ x , y = self .window .get_size ()
298+ text_margin = 200
299+ self .win_data .center .x = (x - text_margin ) / 2
300+ self .win_data .center .y = y / 2
301+ ratio = float (x - text_margin ) / float (self .win_data .scene_xmax - self .win_data .scene_xmin )
302+ ratio_y = float (y ) / float (self .win_data .scene_ymax - self .win_data .scene_ymin )
303+ if ratio_y < ratio :
304+ ratio = ratio_y
305+
306+ self .win_data .node_size = max (ratio / 100 , 10 )
307+ self .win_data .stretch = 0.75 * ratio
308+
309+ self .window .fill ((255 , 255 , 255 ))
310+ pygame .display .flip ()
311+
312+ self .graph .clear ()
313+ self .mesh_disp = MeshDisplay (self .mesh )
314+ self .graph .update (self .mesh_disp .get_nodes_coordinates (), self .mesh_disp .get_edges (),
315+ self .mesh_disp .get_scores ())
316+
317+ #Draw mesh
318+ for e in self .graph .edges :
319+ e .draw (self .window , self .win_data )
320+ for n in self .graph .vertices :
321+ n .draw (self .window , self .font , self .win_data )
322+
323+ #Print action type
324+ if hasattr (self , 'actions_info' ):
325+ x = self .window .get_width () - 150
326+ y_start = 100
327+ line_spacing = 25
328+
329+ for i , (action_name , count ) in enumerate (self .actions_info .items ()):
330+ text = f"{ action_name } : { count } "
331+ text_surface = self .font .render (text , True , (0 , 0 , 0 ))
332+ self .window .blit (text_surface , (x , y_start + i * line_spacing ))
333+
334+ self .clock .tick (60 )
335+ pygame .time .delay (1200 )
336+ pygame .display .flip ()
337+ if self .recording :
338+ pixels = pygame .surfarray .array3d (pygame .display .get_surface ())
339+ frame = pixels .transpose ([1 ,0 ,2 ])
340+ self .frames .append (frame )
341+
342+ def close (self ):
343+ if self .render_mode == "human" and self .window is not None :
344+ pygame .display .quit ()
345+ pygame .quit ()
346+ self .window = None
0 commit comments