
Commit a73713d
Functional quad environment with actions flip, split, collapse.
1 parent 4eaa01a

File tree: 23 files changed, +492 / -217 lines

Training2/__init__.py

Whitespace-only changes.

train.py renamed to Training2/train.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def train():
    if rewards is not None:
        plot_training_results(rewards, wins, steps)

-   torch.save(actor.state_dict(), 'policy_saved/actor_network.pth')
+   torch.save(actor.state_dict(), '../policy_saved/actor_network.pth')
    avg_steps, avg_wins, avg_rewards, final_meshes = testPolicy(actor, 5, dataset, 60)

    if rewards is not None:
Lines changed: 2 additions & 2 deletions

@@ -94,9 +94,9 @@ def _on_training_end(self) -> None:
        self.logger.dump(step=0)


-with open("model_RL/parameters/ppo_config.json", "r") as f:
+with open("../model_RL/parameters/ppo_config.json", "r") as f:
    ppo_config = json.load(f)
-with open("environment/parameters/environment_config.json", "r") as f:
+with open("../environment/parameters/environment_config.json", "r") as f:
    env_config = json.load(f)

# Create log dir
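
For orientation, a hypothetical sketch of the two JSON configuration files loaded above, written here as Python dicts. The keys are exactly the ones the training scripts in this commit read; the env values mirror the defaults registered for Quadmesh-v0, while the PPO values are illustrative placeholders, not the repository's actual settings.

# Illustrative placeholders only; keys mirror what the training scripts read.
ppo_config = {
    "policy": "MlpPolicy",
    "n_steps": 2048,
    "n_epochs": 10,
    "batch_size": 64,
    "learning_rate": 3e-4,
    "gamma": 0.99,
    "verbose": 1,
    "tensorboard_log": "./logs/",
    "total_timesteps": 100_000,
}
env_config = {
    "env_name": "Quadmesh-v0",
    "max_episode_steps": 100,
    "n_darts_selected": 20,
    "deep": 6,
    "action_restriction": False,
    "with_degree_observation": True,
}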

Training2/train_quadmesh_SB3.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
from __future__ import annotations

import os
import json


import mesh_model.random_quadmesh as QM
from environment.gymnasium_envs.quadmesh_env import QuadMeshEnv
from plots.mesh_plotter import dataset_plt
from exploit_SB3_policy import testPolicy
from stable_baselines3 import PPO, SAC
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figure

import gymnasium as gym


class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """
    def __init__(self, model, verbose=0):
        super().__init__(verbose)
        self.model = model
        self.episode_rewards = []
        self.mesh_reward = 0
        self.current_episode_reward = 0
        self.episode_count = 0
        self.current_episode_length = 0
        self.actions_info = {
            "episode_valid_actions": 0,
            "episode_invalid_topo": 0,
            "episode_invalid_geo": 0,
            "nb_flip": 0,
            "nb_split": 0,
            "nb_collapse": 0,
            "nb_cleanup": 0,
            "nb_invalid_flip": 0,
            "nb_invalid_split": 0,
            "nb_invalid_collapse": 0,
            "nb_invalid_cleanup": 0,
        }
        self.final_distance = 0
        self.normalized_return = 0

    def _on_training_start(self) -> None:
        """
        Record the PPO parameters and environment configuration at training start.
        """
        self.logger.record("parameters/ppo", f"<pre>{json.dumps(ppo_config, indent=4)}</pre>")
        self.logger.record("parameters/env", f"<pre>{json.dumps(env_config, indent=4)}</pre>")
        self.logger.dump(step=0)

    def _on_step(self) -> bool:
        """
        Record the different learning variables to monitor.
        """
        self.current_episode_reward += self.locals["rewards"][0]
        self.current_episode_length += 1

        self.actions_info["episode_valid_actions"] += self.locals["infos"][0].get("valid_action", 0.0)
        self.actions_info["episode_invalid_topo"] += self.locals["infos"][0].get("invalid_topo", 0.0)
        self.actions_info["episode_invalid_geo"] += self.locals["infos"][0].get("invalid_geo", 0.0)
        self.actions_info["nb_flip"] += self.locals["infos"][0].get("flip", 0.0)
        self.actions_info["nb_split"] += self.locals["infos"][0].get("split", 0.0)
        self.actions_info["nb_collapse"] += self.locals["infos"][0].get("collapse", 0.0)
        self.actions_info["nb_cleanup"] += self.locals["infos"][0].get("cleanup", 0.0)
        self.actions_info["nb_invalid_flip"] += self.locals["infos"][0].get("invalid_flip", 0.0)
        self.actions_info["nb_invalid_split"] += self.locals["infos"][0].get("invalid_split", 0.0)
        self.actions_info["nb_invalid_collapse"] += self.locals["infos"][0].get("invalid_collapse", 0.0)
        self.actions_info["nb_invalid_cleanup"] += self.locals["infos"][0].get("invalid_cleanup", 0.0)

        self.mesh_reward += self.locals["infos"][0].get("mesh_reward", 0.0)

        # When the episode is over
        if self.locals["dones"][0]:
            self.episode_rewards.append(self.current_episode_reward)  # global reward obtained during the episode
            mesh_ideal_reward = self.locals["infos"][0].get("mesh_ideal_rewards", 0.0)  # maximum achievable reward
            if mesh_ideal_reward > 0:
                self.normalized_return = self.mesh_reward / mesh_ideal_reward
            else:
                self.normalized_return = 0

            self.final_distance = self.locals["infos"][0].get("distance", 0.0)
            self.logger.record("final_distance", self.final_distance)
            self.logger.record("valid_actions", self.actions_info["episode_valid_actions"] * 100 / self.current_episode_length if self.current_episode_length > 0 else 0)
            self.logger.record("n_invalid_topo", self.actions_info["episode_invalid_topo"])
            self.logger.record("n_invalid_geo", self.actions_info["episode_invalid_geo"])
            self.logger.record("nb_flip", self.actions_info["nb_flip"])
            self.logger.record("nb_split", self.actions_info["nb_split"])
            self.logger.record("nb_collapse", self.actions_info["nb_collapse"])
            self.logger.record("nb_cleanup", self.actions_info["nb_cleanup"])
            self.logger.record("invalid_flip", self.actions_info["nb_invalid_flip"] * 100 / self.actions_info["nb_flip"] if self.actions_info["nb_flip"] > 0 else 0)
            self.logger.record("invalid_split", self.actions_info["nb_invalid_split"] * 100 / self.actions_info["nb_split"] if self.actions_info["nb_split"] > 0 else 0)
            self.logger.record("invalid_collapse", self.actions_info["nb_invalid_collapse"] * 100 / self.actions_info["nb_collapse"] if self.actions_info["nb_collapse"] > 0 else 0)
            self.logger.record("invalid_cleanup", self.actions_info["nb_invalid_cleanup"] * 100 / self.actions_info["nb_cleanup"] if self.actions_info["nb_cleanup"] > 0 else 0)
            self.logger.record("episode_mesh_reward", self.mesh_reward)
            self.logger.record("episode_reward", self.current_episode_reward)
            self.logger.record("normalized_return", self.normalized_return)
            self.logger.record("episode_length", self.current_episode_length)

            is_success = self.locals["infos"][0].get("is_success", 0.0)  # default value: 0.0
            self.logger.record("episode_success", is_success)

            self.logger.dump(step=self.episode_count)
            self.current_episode_reward = 0  # reset global episode reward
            self.mesh_reward = 0  # reset mesh episode reward
            self.current_episode_length = 0
            # Reset actions info
            for key in self.actions_info.keys():
                self.actions_info[key] = 0
            self.episode_count += 1  # increment episode counter

        return True

    def _on_training_end(self) -> None:
        """
        Record policy evaluation results: before and after dataset images.
        """
        dataset = [QM.random_mesh() for _ in range(9)]  # dataset of 9 random quad meshes
        before = dataset_plt(dataset)  # plot the dataset as an image
        length, wins, rewards, normalized_return, final_meshes = testPolicy(self.model, 10, env_config, dataset)  # test the model policy on the dataset
        after = dataset_plt(final_meshes)
        self.logger.record("figures/before", Figure(before, close=True), exclude=("stdout", "log"))
        self.logger.record("figures/after", Figure(after, close=True), exclude=("stdout", "log"))
        self.logger.dump(step=0)


with open("../model_RL/parameters/ppo_config.json", "r") as f:
    ppo_config = json.load(f)
with open("../environment/parameters/environment_config.json", "r") as f:
    env_config = json.load(f)

# Create log dir
log_dir = ppo_config["tensorboard_log"]
os.makedirs(log_dir, exist_ok=True)

# Create the environment
env = gym.make(
    env_config["env_name"],
    max_episode_steps=env_config["max_episode_steps"],
    n_darts_selected=env_config["n_darts_selected"],
    deep=env_config["deep"],
    action_restriction=env_config["action_restriction"],
    with_degree_obs=env_config["with_degree_observation"]
)

check_env(env, warn=True)

model = PPO(
    policy=ppo_config["policy"],
    env=env,
    n_steps=ppo_config["n_steps"],
    n_epochs=ppo_config["n_epochs"],
    batch_size=ppo_config["batch_size"],
    learning_rate=ppo_config["learning_rate"],
    gamma=ppo_config["gamma"],
    verbose=ppo_config["verbose"],
    tensorboard_log=log_dir
)

print("-----------Starting learning-----------")
model.learn(total_timesteps=ppo_config["total_timesteps"], callback=TensorboardCallback(model))
print("-----------Learning ended------------")
model.save("policy_saved/quad/test3")
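
Not part of this commit, but for context: a minimal evaluation sketch showing how the policy saved at policy_saved/quad/test3 could be reloaded and rolled out for one episode. It assumes the same environment_config.json keys used above; importing the quadmesh_env package is what registers Quadmesh-v0 (see the __init__.py diff below).

# Evaluation sketch (assumptions: paths and config keys match the training script above).
import json

import gymnasium as gym
from stable_baselines3 import PPO

from environment.gymnasium_envs.quadmesh_env import QuadMeshEnv  # noqa: F401  importing registers Quadmesh-v0

with open("../environment/parameters/environment_config.json", "r") as f:
    env_config = json.load(f)

env = gym.make(
    env_config["env_name"],
    max_episode_steps=env_config["max_episode_steps"],
    n_darts_selected=env_config["n_darts_selected"],
    deep=env_config["deep"],
    action_restriction=env_config["action_restriction"],
    with_degree_obs=env_config["with_degree_observation"],
)

model = PPO.load("policy_saved/quad/test3")

obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    action, _ = model.predict(obs, deterministic=True)  # greedy action from the trained policy
    obs, reward, terminated, truncated, info = env.step(action)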

train_SB3.py renamed to Training2/train_trimesh_SB3.py

Lines changed: 3 additions & 3 deletions
@@ -120,9 +120,9 @@ def _on_training_end(self) -> None:
        self.logger.dump(step=0)


-with open("model_RL/parameters/ppo_config.json", "r") as f:
+with open("../model_RL/parameters/ppo_config.json", "r") as f:
    ppo_config = json.load(f)
-with open("environment/parameters/environment_config.json", "r") as f:
+with open("../environment/parameters/environment_config.json", "r") as f:
    env_config = json.load(f)

# Create log dir

@@ -157,4 +157,4 @@ def _on_training_end(self) -> None:
print("-----------Starting learning-----------")
model.learn(total_timesteps=ppo_config["total_timesteps"], callback=TensorboardCallback(model))
print("-----------Learning ended------------")
-model.save("policy_saved/final/final-PPO-4")
+model.save("policy_saved/test/test-PPO-4")

actions/quadrangular_actions.py

Lines changed: 0 additions & 2 deletions
@@ -190,8 +190,6 @@ def cleanup_edge(mesh: Mesh, n1: Node, n2: Node) -> True:

    mesh.del_quad(d, d1, d11, d111, f)

-
-
    adj_darts = adjacent_darts(n_from)

    for d in adj_darts:

environment/gymnasium_envs/quadmesh_env/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,5 +5,5 @@
    id="Quadmesh-v0",
    entry_point="environment.gymnasium_envs.quadmesh_env.envs:QuadMeshEnv",
    max_episode_steps=100,
-    kwargs={"mesh": None, "mesh_size": 30, "n_darts_selected": 20, "deep": 6, "with_degree_obs": True, "action_restriction": False},
+    kwargs={"mesh": None, "n_darts_selected": 20, "deep": 6, "with_degree_obs": True, "action_restriction": False},
)
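
A hypothetical usage note for the updated registration: with mesh_size removed from the default kwargs, a caller either relies on the default mesh=None or passes a mesh explicitly. Whether the environment builds its own mesh when mesh is None is not shown in this diff and is assumed here.

import gymnasium as gym

import mesh_model.random_quadmesh as QM
from environment.gymnasium_envs.quadmesh_env import QuadMeshEnv  # noqa: F401  importing registers Quadmesh-v0

env_default = gym.make("Quadmesh-v0")  # registered kwargs as-is (mesh=None)
env_explicit = gym.make("Quadmesh-v0", mesh=QM.random_mesh())  # override the registered mesh kwarg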
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from environment.gymnasium_envs.trimesh_full_env.envs.trimesh import TriMeshEnvFull
+from environment.gymnasium_envs.quadmesh_env.envs.quadmesh import QuadMeshEnv
