Skip to content

Commit 305bda5

Browse files
committed
File reorganisation, for clean implementation of PPO on quad files.
1 parent a73713d commit 305bda5

File tree

126 files changed

+377
-88
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

126 files changed

+377
-88
lines changed

actions/quadrangular_actions.py renamed to environment/actions/quadrangular_actions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from mesh_model.mesh_struct.mesh import Mesh
4-
from mesh_model.mesh_struct.mesh_elements import Dart, Node
5-
from mesh_model.mesh_analysis.mesh_analysis import adjacent_darts, degree, mesh_check
4+
from mesh_model.mesh_struct.mesh_elements import Node
5+
from mesh_model.mesh_analysis.global_mesh_analysis import adjacent_darts, degree, mesh_check
66
from mesh_model.mesh_analysis.quadmesh_analysis import isFlipOk, isCollapseOk, isSplitOk, isCleanupOk
77

88

@@ -198,3 +198,5 @@ def cleanup_edge(mesh: Mesh, n1: Node, n2: Node) -> True:
198198
mesh.del_node(n_from)
199199

200200
return mesh_check(mesh), topo, geo
201+
202+

environment/actions/smoothing.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from mesh_model.mesh_struct.mesh import Mesh
4+
from mesh_model.mesh_struct.mesh_elements import Node
5+
from mesh_model.mesh_analysis.global_mesh_analysis import adjacent_darts, on_boundary
6+
7+
def smoothing_mean(mesh: Mesh) -> True:
8+
for i in range (20):
9+
#plot_mesh(mesh)
10+
for i, n_info in enumerate (mesh.nodes, start=0):
11+
if n_info[2] >=0:
12+
node_to_smooth = Node(mesh, i)
13+
if not on_boundary(node_to_smooth):
14+
list_darts = adjacent_darts(node_to_smooth)
15+
sum_x = 0.0
16+
sum_y = 0.0
17+
nb_nodes = 0.0
18+
for d in list_darts:
19+
n = d.get_node()
20+
if n != node_to_smooth:
21+
sum_x += n.x()
22+
sum_y += n.y()
23+
nb_nodes += 1
24+
node_to_smooth.set_xy(sum_x/nb_nodes, sum_y/nb_nodes)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from mesh_model.mesh_struct.mesh import Mesh
44
from mesh_model.mesh_struct.mesh_elements import Dart, Node
5-
from mesh_model.mesh_analysis.mesh_analysis import mesh_check
5+
from mesh_model.mesh_analysis.global_mesh_analysis import mesh_check
66
from mesh_model.mesh_analysis.trimesh_analysis import isFlipOk, isCollapseOk, isSplitOk
77

88

environment/parameters/environment_config.json renamed to environment/environment_config.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
"env_name": "Quadmesh-v0",
33
"mesh_size": 16,
44
"max_episode_steps": 30,
5-
"n_darts_selected": 10,
6-
"deep": 8,
5+
"n_darts_selected": 25,
6+
"deep": 6,
77
"action_restriction": false,
88
"with_degree_observation": false
99
}

environment/gymnasium_envs/quadmesh_env/envs/quadmesh.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
from enum import Enum
22
import gymnasium as gym
33
from gymnasium import spaces
4-
import pygame
54
import numpy as np
5+
import copy
66

77
from mesh_model.random_quadmesh import random_mesh
88
from mesh_model.mesh_struct.mesh_elements import Dart
99
from mesh_model.mesh_analysis.quadmesh_analysis import global_score, isTruncated
1010
from environment.gymnasium_envs.quadmesh_env.envs.mesh_conv import get_x
11-
from actions.quadrangular_actions import flip_edge, split_edge, collapse_edge, cleanup_edge
12-
13-
from view.window import Game
14-
from mesh_display import MeshDisplay
11+
from environment.actions.quadrangular_actions import flip_edge, split_edge, collapse_edge, cleanup_edge
1512

1613

1714
class Actions(Enum):
@@ -25,7 +22,12 @@ class QuadMeshEnv(gym.Env):
2522
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 60}
2623

2724
def __init__(self, mesh=None, n_darts_selected=20, deep=6, with_degree_obs=True, action_restriction=False, render_mode=None):
28-
self.mesh = mesh if mesh is not None else random_mesh()
25+
if mesh is not None:
26+
self.config = {"mesh": mesh}
27+
self.mesh = copy.deepcopy(mesh)
28+
else :
29+
self.config = {"mesh": None}
30+
self.mesh = random_mesh()
2931
self.mesh_size = len(self.mesh.nodes)
3032
self.nb_darts = len(self.mesh.dart_info)
3133
self._nodes_scores, self._mesh_score, self._ideal_score, self._nodes_adjacency = global_score(self.mesh)
@@ -40,7 +42,6 @@ def __init__(self, mesh=None, n_darts_selected=20, deep=6, with_degree_obs=True,
4042
self.nb_invalid_actions = 0
4143
self.darts_selected = [] # darts id observed
4244
deep = self.deep*2 if self.degree_observation else deep
43-
4445
self.observation_space = spaces.Box(
4546
low=-15, # nodes min degree : -15
4647
high=15, # nodes max degree : 15
@@ -59,6 +60,8 @@ def reset(self, seed=None, options=None):
5960
super().reset(seed=seed)
6061
if options is not None:
6162
self.mesh = options['mesh']
63+
elif self.config["mesh"] is not None:
64+
self.mesh = copy.deepcopy(self.config["mesh"])
6265
else:
6366
self.mesh = random_mesh()
6467
self.nb_darts = len(self.mesh.dart_info)

environment/gymnasium_envs/trimesh_flip_env/envs/trimesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from mesh_model.mesh_struct.mesh_elements import Dart
88
from mesh_model.mesh_analysis.trimesh_analysis import global_score, isTruncated
99
from environment.gymnasium_envs.trimesh_flip_env.envs.mesh_conv import get_x
10-
from actions.triangular_actions import flip_edge
10+
from environment.actions.triangular_actions import flip_edge
1111

1212
from view.window import Game
1313
from mesh_display import MeshDisplay

environment/gymnasium_envs/trimesh_full_env/envs/trimesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mesh_model.mesh_struct.mesh_elements import Dart
99
from mesh_model.mesh_analysis.trimesh_analysis import global_score, isTruncated
1010
from environment.gymnasium_envs.trimesh_full_env.envs.mesh_conv import get_x
11-
from actions.triangular_actions import flip_edge, split_edge, collapse_edge
11+
from environment.actions.triangular_actions import flip_edge, split_edge, collapse_edge
1212

1313
from view.window import Game
1414
from mesh_display import MeshDisplay
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from trimesh_env.wrappers.clip_reward import ClipReward
2-
from trimesh_env.wrappers.discrete_actions import DiscreteActions
3-
from trimesh_env.wrappers.reacher_weighted_reward import ReacherRewardWrapper
4-
from trimesh_env.wrappers.relative_position import RelativePosition
1+
from environment.gymnasium_envs.trimesh_full_env.wrappers.clip_reward import ClipReward
2+
from environment.gymnasium_envs.trimesh_full_env.wrappers.discrete_actions import DiscreteActions
3+
from environment.gymnasium_envs.trimesh_full_env.wrappers.reacher_weighted_reward import ReacherRewardWrapper
4+
from environment.gymnasium_envs.trimesh_full_env.wrappers.relative_position import RelativePosition

environment/trimesh_env.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from typing import Any
21
import math
32
import numpy as np
4-
from mesh_model.mesh_analysis import global_score, find_template_opposite_node
3+
from mesh_model.mesh_analysis.trimesh_analysis import global_score, find_template_opposite_node
54
from mesh_model.mesh_struct.mesh_elements import Dart
65
from mesh_model.mesh_struct.mesh import Mesh
7-
from actions.triangular_actions import flip_edge, split_edge, collapse_edge
6+
from environment.actions.triangular_actions import flip_edge, split_edge, collapse_edge
87
from mesh_model.random_trimesh import random_flip_mesh, random_mesh
98

109
# possible actions
@@ -23,7 +22,7 @@ def __init__(self, mesh=None, mesh_size: int = None, max_steps: int = 50, feat:
2322
self.reward = 0
2423
self.steps = 0
2524
self.max_steps = max_steps
26-
self.nodes_scores, self.mesh_score, self.ideal_score = global_score(self.mesh)
25+
self.nodes_scores, self.mesh_score, self.ideal_score, _ = global_score(self.mesh)
2726
self.terminal = False
2827
self.feat = feat
2928
self.won = 0
@@ -34,7 +33,7 @@ def reset(self, mesh=None):
3433
self.terminal = False
3534
self.mesh = mesh if mesh is not None else random_mesh(self.mesh_size)
3635
self.size = len(self.mesh.dart_info)
37-
self.nodes_scores, self.mesh_score, self.ideal_score = global_score(self.mesh)
36+
self.nodes_scores, self.mesh_score, self.ideal_score, _ = global_score(self.mesh)
3837
self.won = 0
3938

4039
def step(self, action):
@@ -50,7 +49,7 @@ def step(self, action):
5049
elif action[2] == COLLAPSE:
5150
collapse_edge(self.mesh, n1, n2)
5251
self.steps += 1
53-
next_nodes_score, next_mesh_score, _ = global_score(self.mesh)
52+
next_nodes_score, next_mesh_score, _, _ = global_score(self.mesh)
5453
self.nodes_scores = next_nodes_score
5554
self.reward = (self.mesh_score - next_mesh_score) * 10
5655
if self.steps >= self.max_steps or next_mesh_score == self.ideal_score:

0 commit comments

Comments
 (0)