Skip to content

Commit 7e9db50

Browse files
authored
Merge pull request #64 from LIHPC-Computational-Geometry/63-improve-training-scripts
63 improve training scripts
2 parents 1865d67 + 05d5dc2 commit 7e9db50

File tree

16 files changed

+947
-195
lines changed

16 files changed

+947
-195
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"env_name": "Quadmesh-v0",
3+
"mesh_size": 16,
4+
"max_episode_steps": 20,
5+
"n_darts_selected": 10,
6+
"deep": 8,
7+
"action_restriction": false,
8+
"with_degree_observation": false
9+
}

environment/gymnasium_envs/quadmesh_env/envs/mesh_conv.py

Lines changed: 83 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
def get_x(state: Mesh, n_darts_selected: int, deep :int, degree: bool, restricted:bool, nodes_scores: list[int], nodes_adjacency: list[int]):
88
mesh = state
99
if degree:
10-
template, darts_id = get_template_deg(mesh, deep, nodes_scores, nodes_adjacency)
10+
deep = int(deep / 2)
11+
template, darts_id = get_template_boundary(mesh, deep, nodes_scores, nodes_adjacency)
1112
else:
1213
template, darts_id = get_template(mesh, deep, nodes_scores)
1314

@@ -80,6 +81,8 @@ def get_template(mesh: Mesh, deep: int, nodes_scores):
8081
template[n_darts - 1, len(E)-1] = nodes_scores[N2.id]
8182
else:
8283
E.extend([None,None])
84+
#template[n_darts - 1, len(E) - 1] = -500 # dummy vertices are assigned to -500
85+
#template[n_darts - 1, len(E) - 2] = -500 # dummy vertices are assigned to -500
8386

8487
template = template[:n_darts, :]
8588

@@ -124,19 +127,85 @@ def get_template_deg(mesh: Mesh, deep: int, nodes_scores, nodes_adjacency):
124127
if deep > 4:
125128
while len(E) < deep:
126129
df = F.pop(0)
127-
df1 = df.get_beta(1)
128-
df11 = df1.get_beta(1)
129-
df111 = df11.get_beta(1)
130-
F.append(df1)
131-
F.append(df11)
132-
F.append(df111)
133-
N1, N2 = df11.get_node(), df111.get_node()
134-
E.append(N1)
135-
template[n_darts - 1, len(E)] = nodes_scores[N1.id]
136-
template[n_darts - 1, deep + len(E)] = nodes_adjacency[N1.id]
137-
E.append(N2)
138-
template[n_darts - 1, len(E)] = nodes_scores[N2.id]
139-
template[n_darts - 1, deep + len(E)] = nodes_adjacency[N2.id]
130+
if df is not None:
131+
df1 = df.get_beta(1)
132+
df11 = df1.get_beta(1)
133+
df111 = df11.get_beta(1)
134+
F.append(df1)
135+
F.append(df11)
136+
F.append(df111)
137+
N1, N2 = df11.get_node(), df111.get_node()
138+
E.append(N1)
139+
template[n_darts-1, len(E)-1] = nodes_scores[N1.id]
140+
template[n_darts-1, deep + len(E)-1] = nodes_adjacency[N1.id]
141+
E.append(N2)
142+
template[n_darts - 1, len(E)-1] = nodes_scores[N2.id]
143+
template[n_darts - 1, deep + len(E)-1] = nodes_adjacency[N2.id]
144+
else:
145+
E.extend([None,None])
146+
#template[n_darts - 1, len(E) - 1] = -500 # dummy vertices are assigned to -500
147+
#template[n_darts - 1, len(E) - 2] = -500 # dummy vertices are assigned to -500
148+
149+
template = template[:n_darts, :]
150+
return template, dart_ids
151+
152+
def get_template_boundary(mesh: Mesh, deep: int, nodes_scores, nodes_adjacency):
153+
size = len(mesh.dart_info)
154+
template = np.zeros((size, deep*2), dtype=np.int64)
155+
dart_ids = []
156+
n_darts = 0
157+
158+
for d_info in mesh.active_darts():
159+
n_darts += 1
160+
d_id = d_info[0]
161+
dart_ids.append(d_id)
162+
d = Dart(mesh, d_id)
163+
A = d.get_node()
164+
d1 = d.get_beta(1)
165+
B = d1.get_node()
166+
d11 = d1.get_beta(1)
167+
C = d11.get_node()
168+
d111 = d11.get_beta(1)
169+
D = d111.get_node()
170+
171+
# Template niveau 1
172+
template[n_darts - 1, 0] = nodes_scores[A.id]
173+
template[n_darts - 1, deep] = 1
174+
template[n_darts - 1, 1] = nodes_scores[B.id]
175+
template[n_darts - 1, deep+1] = 1
176+
template[n_darts - 1, 2] = nodes_scores[C.id]
177+
template[n_darts - 1, deep+2] = 1
178+
template[n_darts - 1, 3] = nodes_scores[D.id]
179+
template[n_darts - 1, deep + 3] = 1
180+
181+
E = [A, B, C, D]
182+
deep_captured = len(E)
183+
d2 = d.get_beta(2)
184+
d12 = d1.get_beta(2)
185+
d112 = d11.get_beta(2)
186+
d1112 = d111.get_beta(2)
187+
F = [d2, d12, d112, d1112]
188+
if deep > 4:
189+
while len(E) < deep:
190+
df = F.pop(0)
191+
if df is not None:
192+
df1 = df.get_beta(1)
193+
df11 = df1.get_beta(1)
194+
df111 = df11.get_beta(1)
195+
F.append(df1)
196+
F.append(df11)
197+
F.append(df111)
198+
N1, N2 = df11.get_node(), df111.get_node()
199+
E.append(N1)
200+
template[n_darts-1, len(E)-1] = nodes_scores[N1.id]
201+
template[n_darts-1, deep + len(E)-1] = 1
202+
E.append(N2)
203+
template[n_darts - 1, len(E)-1] = nodes_scores[N2.id]
204+
template[n_darts - 1, deep + len(E)-1] = 1
205+
else:
206+
E.extend([None,None])
207+
#template[n_darts - 1, len(E) - 1] = -500 # dummy vertices are assigned to -500
208+
#template[n_darts - 1, len(E) - 2] = -500 # dummy vertices are assigned to -500
140209

141210
template = template[:n_darts, :]
142211
return template, dart_ids

0 commit comments

Comments
 (0)