99from mesh_model .mesh_analysis .global_mesh_analysis import global_score
1010from mesh_model .mesh_analysis .quadmesh_analysis import isTruncated
1111from environment .gymnasium_envs .quadmesh_env .envs .mesh_conv import get_x
12- from environment .actions .quadrangular_actions import flip_edge , split_edge , collapse_edge , cleanup_edge
12+ from environment .actions .quadrangular_actions import flip_edge_cntcw , flip_edge_cw , split_edge , collapse_edge , cleanup_edge
1313from environment .observation_register import ObservationRegistry
1414
1515
1616class Actions (Enum ):
17- FLIP = 0
18- SPLIT = 1
19- COLLAPSE = 2
20- CLEANUP = 3
17+ FLIP_CW = 0
18+ FLIP_CNTCW = 1
19+ SPLIT = 2
20+ COLLAPSE = 3
21+ CLEANUP = 4
2122
2223
2324class QuadMeshEnv (gym .Env ):
@@ -45,17 +46,17 @@ def __init__(self, mesh=None, max_episode_steps=30, n_darts_selected=20, deep=6,
4546 self .darts_selected = [] # darts id observed
4647 self .deep = deep * 2 if self .degree_observation else deep
4748 self .observation_space = spaces .Box (
48- low = - 6 , # nodes min degree : -15
49- high = 2 , # nodes max degree : 15
49+ low = - 6 , # nodes min degree : -6
50+ high = 2 , # nodes max degree : 2
5051 shape = (self .n_darts_selected , deep ),
5152 dtype = np .int64
5253 )
5354 self .observation_count = ObservationRegistry (self .n_darts_selected , self .deep , - 6 , 2 )
5455
5556 self .observation = None
5657
57- # We have 4 actions, flip, split, collapse, cleanup
58- self .action_space = spaces .MultiDiscrete ([3 , self .n_darts_selected ])
58+ # We have 4 actions, flip clockwise, flip counterclockwise , split, collapse, cleanup
59+ self .action_space = spaces .MultiDiscrete ([4 , self .n_darts_selected ])
5960
6061
6162 def reset (self , seed = None , options = None ):
@@ -96,11 +97,12 @@ def _get_info(self, terminated, valid_act, action, mesh_reward):
9697 "valid_action" : 1.0 if valid_action else 0.0 ,
9798 "invalid_topo" : 1.0 if not valid_topo else 0.0 ,
9899 "invalid_geo" : 1.0 if not valid_geo else 0.0 ,
99- "flip" : 1.0 if action [0 ]== Actions .FLIP .value else 0.0 ,
100+ "flip_cw" : 1.0 if action [0 ]== Actions .FLIP_CW .value else 0.0 ,
101+ "flip_cntcw" : 1.0 if action [0 ]== Actions .FLIP_CNTCW .value else 0.0 ,
100102 "split" : 1.0 if action [0 ]== Actions .SPLIT .value else 0.0 ,
101103 "collapse" : 1.0 if action [0 ]== Actions .COLLAPSE .value else 0.0 ,
102104 "cleanup" : 1.0 if action [0 ]== Actions .CLEANUP .value else 0.0 ,
103- "invalid_flip" : 1.0 if action [0 ]== Actions .FLIP .value and not valid_action else 0.0 ,
105+ "invalid_flip" : 1.0 if ( action [0 ]== Actions .FLIP_CW .value or action [ 0 ] == Actions . FLIP_CNTCW . value ) and not valid_action else 0.0 ,
104106 "invalid_split" : 1.0 if action [0 ]== Actions .SPLIT .value and not valid_action else 0.0 ,
105107 "invalid_collapse" : 1.0 if action [0 ]== Actions .COLLAPSE .value and not valid_action else 0.0 ,
106108 "invalid_cleanup" : 1.0 if action [0 ]== Actions .CLEANUP .value and not valid_action else 0.0 ,
@@ -125,8 +127,10 @@ def step(self, action: np.ndarray):
125127 n2 = d1 .get_node ()
126128 valid_action , valid_topo , valid_geo = False , False , False
127129
128- if action [0 ] == Actions .FLIP .value :
129- valid_action , valid_topo , valid_geo = flip_edge (self .mesh , n1 , n2 )
130+ if action [0 ] == Actions .FLIP_CW .value :
131+ valid_action , valid_topo , valid_geo = flip_edge_cw (self .mesh , n1 , n2 )
132+ elif action [0 ] == Actions .FLIP_CNTCW .value :
133+ valid_action , valid_topo , valid_geo = flip_edge_cntcw (self .mesh , n1 , n2 )
130134 elif action [0 ] == Actions .SPLIT .value :
131135 valid_action , valid_topo , valid_geo = split_edge (self .mesh , n1 , n2 )
132136 elif action [0 ] == Actions .COLLAPSE .value :
0 commit comments