11from typing import Any
2+ import math
23import numpy as np
3- from model .mesh_analysis import global_score , isValidAction , find_template_opposite_node
4- from model .mesh_struct .mesh_elements import Dart
5- from model .mesh_struct .mesh import Mesh
6- from actions .triangular_actions import flip_edge
7- from model .random_trimesh import random_flip_mesh
4+ from mesh_model .mesh_analysis import global_score , find_template_opposite_node
5+ from mesh_model .mesh_struct .mesh_elements import Dart
6+ from mesh_model .mesh_struct .mesh import Mesh
7+ from actions .triangular_actions import flip_edge , split_edge , collapse_edge
8+ from mesh_model .random_trimesh import random_flip_mesh , random_mesh
89
910# possible actions
1011FLIP = 0
12+ SPLIT = 1
13+ COLLAPSE = 2
1114GLOBAL = 0
1215
1316
1417class TriMesh :
1518 def __init__ (self , mesh = None , mesh_size : int = None , max_steps : int = 50 , feat : int = 0 ):
1619 self .mesh = mesh if mesh is not None else random_flip_mesh (mesh_size )
17- self .mesh_size = len (self .mesh .nodes )
20+ self .mesh_size = len (self .mesh .active_nodes () )
1821 self .size = len (self .mesh .dart_info )
19- self .actions = np .array ([FLIP ])
22+ self .actions = np .array ([FLIP , SPLIT , COLLAPSE ])
2023 self .reward = 0
2124 self .steps = 0
2225 self .max_steps = max_steps
23- self .nodes_scores = global_score (self .mesh )[0 ]
24- self .ideal_score = global_score (self .mesh )[2 ]
26+ self .nodes_scores , self .mesh_score , self .ideal_score = global_score (self .mesh )
2527 self .terminal = False
2628 self .feat = feat
2729 self .won = 0
@@ -30,30 +32,34 @@ def reset(self, mesh=None):
3032 self .reward = 0
3133 self .steps = 0
3234 self .terminal = False
33- self .mesh = mesh if mesh is not None else random_flip_mesh (self .mesh_size )
35+ self .mesh = mesh if mesh is not None else random_mesh (self .mesh_size )
3436 self .size = len (self .mesh .dart_info )
35- self .nodes_scores = global_score (self .mesh )[0 ]
36- self .ideal_score = global_score (self .mesh )[2 ]
37+ self .nodes_scores , self .mesh_score , self .ideal_score = global_score (self .mesh )
3738 self .won = 0
3839
3940 def step (self , action ):
4041 dart_id = action [1 ]
41- _ , mesh_score , mesh_ideal_score = global_score (self .mesh )
4242 d = Dart (self .mesh , dart_id )
4343 d1 = d .get_beta (1 )
4444 n1 = d .get_node ()
4545 n2 = d1 .get_node ()
46- flip_edge (self .mesh , n1 , n2 )
46+ if action [2 ] == FLIP :
47+ flip_edge (self .mesh , n1 , n2 )
48+ elif action [2 ] == SPLIT :
49+ split_edge (self .mesh , n1 , n2 )
50+ elif action [2 ] == COLLAPSE :
51+ collapse_edge (self .mesh , n1 , n2 )
4752 self .steps += 1
4853 next_nodes_score , next_mesh_score , _ = global_score (self .mesh )
4954 self .nodes_scores = next_nodes_score
50- self .reward = (mesh_score - next_mesh_score )* 10
51- if self .steps >= self .max_steps or next_mesh_score == mesh_ideal_score :
52- if next_mesh_score == mesh_ideal_score :
55+ self .reward = (self . mesh_score - next_mesh_score ) * 10
56+ if self .steps >= self .max_steps or next_mesh_score == self . ideal_score :
57+ if next_mesh_score == self . ideal_score :
5358 self .won = True
5459 self .terminal = True
60+ self .nodes_scores , self .mesh_score = next_nodes_score , next_mesh_score
5561
56- def get_x (self , s : Mesh , a : int ) -> tuple [ Any , list [ int | list [ int ]]] :
62+ def get_x (self , s : Mesh , a : int ):
5763 """
5864 Get the feature vector of the state-action pair
5965 :param s: the state
@@ -66,19 +72,39 @@ def get_x(self, s: Mesh, a: int) -> tuple[Any, list[int | list[int]]]:
6672 return get_x_global_4 (self , s )
6773
6874
69- def get_x_global_4 (env , state : Mesh ) -> tuple [ Any , list [ int | list [ int ]]] :
75+ def get_x_global_4 (env , state : Mesh ):
7076 """
7177 Get the feature vector of the state.
7278 :param state: the state
7379 :param env: The environment
7480 :return: the feature vector
7581 """
7682 mesh = state
83+ template = get_template_2 (mesh )
84+ darts_to_delete = []
85+ darts_id = []
86+
87+ for i , d_info in enumerate (mesh .active_darts ()):
88+ d_id = d_info [0 ]
89+ if d_info [2 ] == - 1 : #test the validity of all action type
90+ darts_to_delete .append (i )
91+ else :
92+ darts_id .append (d_id )
93+ valid_template = np .delete (template , darts_to_delete , axis = 0 )
94+ score_sum = np .sum (np .abs (valid_template ), axis = 1 )
95+ indices_top_10 = np .argsort (score_sum )[- 5 :][::- 1 ]
96+ valid_dart_ids = [darts_id [i ] for i in indices_top_10 ]
97+ X = valid_template [indices_top_10 , :]
98+ X = X .flatten ()
99+ return X , valid_dart_ids
100+
101+
102+ def get_template_2 (mesh : Mesh ):
77103 nodes_scores = global_score (mesh )[0 ]
78- size = len (mesh .dart_info )
104+ size = len (mesh .active_darts () )
79105 template = np .zeros ((size , 6 ))
80106
81- for d_info in mesh .dart_info :
107+ for i , d_info in enumerate ( mesh .active_darts ()) :
82108
83109 d = Dart (mesh , d_info [0 ])
84110 A = d .get_node ()
@@ -87,35 +113,21 @@ def get_x_global_4(env, state: Mesh) -> tuple[Any, list[int | list[int]]]:
87113 d11 = d1 .get_beta (1 )
88114 C = d11 .get_node ()
89115
90- #Template niveau 1
91- template [d_info [ 0 ] , 0 ] = nodes_scores [C .id ]
92- template [d_info [ 0 ] , 1 ] = nodes_scores [A .id ]
93- template [d_info [ 0 ] , 2 ] = nodes_scores [B .id ]
116+ # Template niveau 1
117+ template [i , 0 ] = nodes_scores [C .id ] if not math . isnan ( nodes_scores [ C . id ]) else 0
118+ template [i , 1 ] = nodes_scores [A .id ] if not math . isnan ( nodes_scores [ A . id ]) else 0
119+ template [i , 2 ] = nodes_scores [B .id ] if not math . isnan ( nodes_scores [ B . id ]) else 0
94120
95- #template niveau 2
121+ # template niveau 2
96122
97123 n_id = find_template_opposite_node (d )
98- if n_id is not None :
99- template [d_info [ 0 ] , 3 ] = nodes_scores [n_id ]
124+ if n_id is not None and not math . isnan ( nodes_scores [ n_id ]) :
125+ template [i , 3 ] = nodes_scores [n_id ]
100126 n_id = find_template_opposite_node (d1 )
101- if n_id is not None :
102- template [d_info [ 0 ] , 4 ] = nodes_scores [n_id ]
127+ if n_id is not None and not math . isnan ( nodes_scores [ n_id ]) :
128+ template [i , 4 ] = nodes_scores [n_id ]
103129 n_id = find_template_opposite_node (d11 )
104- if n_id is not None :
105- template [d_info [0 ], 5 ] = nodes_scores [n_id ]
106-
107- dart_to_delete = []
108- dart_ids = []
109- for i in range (size ):
110- d = Dart (mesh , i )
111- if not isValidAction (mesh , d .id ):
112- dart_to_delete .append (i )
113- else :
114- dart_ids .append (i )
115- valid_template = np .delete (template , dart_to_delete , axis = 0 )
116- score_sum = np .sum (np .abs (valid_template ), axis = 1 )
117- indices_top_10 = np .argsort (score_sum )[- 5 :][::- 1 ]
118- valid_dart_ids = [dart_ids [i ] for i in indices_top_10 ]
119- X = valid_template [indices_top_10 , :]
120- X = X .flatten ()
121- return X , valid_dart_ids
130+ if n_id is not None and not math .isnan (nodes_scores [n_id ]):
131+ template [i , 5 ] = nodes_scores [n_id ]
132+
133+ return template
0 commit comments