99from mesh_model .mesh_analysis .quadmesh_analysis import global_score , isTruncated
1010from environment .gymnasium_envs .quadmesh_env .envs .mesh_conv import get_x
1111from environment .actions .quadrangular_actions import flip_edge , split_edge , collapse_edge , cleanup_edge
12+ from environment .observation_register import ObservationRegistry
1213
1314
1415class Actions (Enum ):
@@ -21,7 +22,7 @@ class Actions(Enum):
2122class QuadMeshEnv (gym .Env ):
2223 metadata = {"render_modes" : ["human" , "rgb_array" ], "render_fps" : 60 }
2324
24- def __init__ (self , mesh = None , n_darts_selected = 20 , deep = 6 , with_degree_obs = True , action_restriction = False , render_mode = None ):
25+ def __init__ (self , mesh = None , max_episode_steps = 30 , n_darts_selected = 20 , deep = 6 , with_degree_obs = True , action_restriction = False , render_mode = None ):
2526 if mesh is not None :
2627 self .config = {"mesh" : mesh }
2728 self .mesh = copy .deepcopy (mesh )
@@ -33,21 +34,22 @@ def __init__(self, mesh=None, n_darts_selected=20, deep=6, with_degree_obs=True,
3334 self ._nodes_scores , self ._mesh_score , self ._ideal_score , self ._nodes_adjacency = global_score (self .mesh )
3435 self ._ideal_rewards = (self ._mesh_score - self ._ideal_score )* 10
3536 self .next_mesh_score = 0
36- self .deep = deep
3737 self .n_darts_selected = n_darts_selected
3838 self .restricted = action_restriction
3939 self .degree_observation = with_degree_obs
4040 self .window_size = 512 # The size of the PyGame window
4141 self .g = None
4242 self .nb_invalid_actions = 0
43+ self .max_steps = max_episode_steps
4344 self .darts_selected = [] # darts id observed
44- deep = self . deep * 2 if self .degree_observation else deep
45+ self . deep = deep * 2 if self .degree_observation else deep
4546 self .observation_space = spaces .Box (
46- low = - 15 , # nodes min degree : -15
47- high = 15 , # nodes max degree : 15
48- shape = (self .n_darts_selected , self . deep * 2 if self . degree_observation else self . deep ),
47+ low = - 6 , # nodes min degree : -15
48+ high = 2 , # nodes max degree : 15
49+ shape = (self .n_darts_selected , deep ),
4950 dtype = np .int64
5051 )
52+ self .observation_count = ObservationRegistry (self .n_darts_selected , self .deep , - 6 , 2 )
5153
5254 self .observation = None
5355
@@ -102,6 +104,8 @@ def _get_info(self, terminated, valid_act, action, mesh_reward):
102104 "invalid_collapse" : 1.0 if action [0 ]== Actions .COLLAPSE .value and not valid_action else 0.0 ,
103105 "invalid_cleanup" : 1.0 if action [0 ]== Actions .CLEANUP .value and not valid_action else 0.0 ,
104106 "mesh" : self .mesh ,
107+ "darts_selected" : self .darts_selected ,
108+ "observation_count" : self .observation_count ,
105109 }
106110
107111 def _action_to_dart_id (self , action : np .ndarray ) -> int :
@@ -131,6 +135,7 @@ def step(self, action: np.ndarray):
131135 else :
132136 raise ValueError ("Action not defined" )
133137
138+ self .observation_count .register_observation (self .observation )
134139 if valid_action :
135140 # An episode is done if the actual score is the same as the ideal
136141 next_nodes_score , self .next_mesh_score , _ , next_nodes_adjacency = global_score (self .mesh )
0 commit comments