 from mesh_model.mesh_analysis.global_mesh_analysis import global_score
 import copy
 import random
+import json
 from tqdm import tqdm
 import numpy as np
 import torch
@@ -58,8 +59,8 @@ def select_action(self, observation, info):
         action = dist.sample()
         action = action.tolist()
         prob = pmf[action]
-        action_dart = int(action / 3)
-        action_type = action % 3
+        action_dart = int(action / 4)
+        action_type = action % 4
         dart_id = info["darts_selected"][action_dart]
         i = 0
         while not isValidAction(info["mesh"], dart_id, action_type):
@@ -70,8 +71,8 @@ def select_action(self, observation, info):
             action = dist.sample()
             action = action.tolist()
             prob = pmf[action]
-            action_dart = int(action / 3)
-            action_type = action % 3
+            action_dart = int(action / 4)
+            action_type = action % 4
             dart_id = info["darts_selected"][action_dart]
             i += 1
         action_list = [action, dart_id, action_type]
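Note: with four operation types per candidate dart instead of three, the flat action index is now decoded as dart = action // 4 and type = action % 4, matching the 4 * 10 actor output further down. A minimal round-trip sketch of that decoding; the names NB_ACTION_TYPES and NB_DARTS are illustrative assumptions, not identifiers from this file:

    NB_ACTION_TYPES = 4   # assumed: was 3 before this change
    NB_DARTS = 10         # assumed: size of the candidate-dart window

    def decode_action(flat_action: int) -> tuple:
        # Flat index -> (dart slot, action type); inverse of dart * NB_ACTION_TYPES + type.
        return flat_action // NB_ACTION_TYPES, flat_action % NB_ACTION_TYPES

    def encode_action(action_dart: int, action_type: int) -> int:
        return action_dart * NB_ACTION_TYPES + action_type

    # Every flat index in [0, 40) maps back to itself.
    assert all(encode_action(*decode_action(a)) == a
               for a in range(NB_DARTS * NB_ACTION_TYPES))
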
@@ -139,7 +140,7 @@ def learn(self, critic_loss):
 class PPO:
     def __init__(self, env, lr, gamma, nb_iterations, nb_episodes_per_iteration, nb_epochs, batch_size):
         self.env = env
-        self.actor = Actor(env, 10 * 8, 3 * 10, lr=0.0001)
+        self.actor = Actor(env, 10 * 8, 4 * 10, lr=0.0001)
         self.critic = Critic(8 * 10, lr=0.0001)
         self.lr = lr
         self.gamma = gamma
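The Actor is now built with a 4 * 10 output head (10 candidate darts x 4 action types) over the same 10 * 8 observation input. For orientation only, a minimal sketch of an actor with those dimensions, assuming a small two-layer MLP with a softmax head; the hidden size and layer structure are assumptions and do not reproduce this repository's Actor class:

    import torch
    import torch.nn as nn

    class TinyActor(nn.Module):
        # Illustrative only: 10 darts x 8 features in, 10 darts x 4 action types out.
        def __init__(self, obs_dim=10 * 8, n_actions=4 * 10, hidden=64):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, n_actions),
                nn.Softmax(dim=-1),  # probability mass over the 40 flat actions
            )

        def forward(self, obs):
            return self.net(obs)

    pmf = TinyActor()(torch.zeros(10 * 8))  # shape: (40,)
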
@@ -165,16 +166,14 @@ def train(self, dataset):
             critic_loss = []
             actor_loss = []
             self.critic.optimizer.zero_grad()
-            G = 0
-            for _, (s, o, a, r, old_prob, next_o, done) in enumerate(batch, 1):
+            for _, (s, o, a, r, G, old_prob, next_o, done) in enumerate(batch, 1):
                 o = torch.tensor(o.flatten(), dtype=torch.float32)
                 next_o = torch.tensor(next_o.flatten(), dtype=torch.float32)
                 value = self.critic(o)
                 pmf = self.actor.forward(o)
                 log_prob = torch.log(pmf[a[0]])
                 next_value = torch.tensor(0.0, dtype=torch.float32) if done else self.critic(next_o)
                 delta = r + 0.9 * next_value - value
-                G = (r + 0.9 * G) / 10
                 _, st, ideal_s, _ = global_score(s)  # comparison with state s and not s+1?
                 if st == ideal_s:
                     continue
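The batch loop now unpacks an 8-field transition (state, observation, action, reward, discounted return G, old probability, next observation, done) instead of rebuilding G locally. A minimal sketch of the per-transition TD error computed in this loop, assuming the 0.9 discount used throughout this file; the numeric values are placeholders:

    import torch

    GAMMA = 0.9  # discount used in this file

    # One stored transition: (s, o, a, r, G, old_prob, next_o, done)
    r, G, done = 1.0, 2.71, False
    value = torch.tensor(0.5)       # critic(o), placeholder
    next_value = torch.tensor(0.4)  # critic(next_o), placeholder

    next_value = torch.tensor(0.0) if done else next_value
    delta = r + GAMMA * next_value - value  # TD(0) error, as in the loop above
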
@@ -221,6 +220,7 @@ def learn(self, writer):
                 ep_reward = 0
                 ep_mesh_reward = 0
                 ideal_reward = info["mesh_ideal_rewards"]
+                G = 0
                 done = False
                 step = 0
                 while step < 40:
@@ -230,20 +230,21 @@ def learn(self, writer):
                     if action is None:
                         wins.append(0)
                         break
-                    gym_action = [action[2], int(action[0] / 3)]
+                    gym_action = [action[2], int(action[0] / 4)]
                     next_obs, reward, terminated, truncated, info = self.env.step(gym_action)
                     ep_reward += reward
                     ep_mesh_reward += info["mesh_reward"]
+                    G = info["mesh_reward"] + 0.9 * G
                     if terminated:
                         if truncated:
                             wins.append(0)
-                            trajectory.append((state, obs, action, reward, prob, next_obs, done))
+                            trajectory.append((state, obs, action, reward, G, prob, next_obs, done))
                         else:
                             wins.append(1)
                             done = True
-                            trajectory.append((state, obs, action, reward, prob, next_obs, done))
+                            trajectory.append((state, obs, action, reward, G, prob, next_obs, done))
                         break
-                    trajectory.append((state, obs, action, reward, prob, next_obs, done))
+                    trajectory.append((state, obs, action, reward, G, prob, next_obs, done))
                     step += 1
                 if len(trajectory) != 0:
                     rewards.append(ep_reward)
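During the rollout, G now accumulates the mesh rewards seen so far with the recursion G_t = r_t + 0.9 * G_{t-1}, and that running value is stored with each transition. A minimal sketch of the recursion on placeholder rewards:

    GAMMA = 0.9

    def running_discounted(rewards):
        # G_t = r_t + GAMMA * G_{t-1}, matching the in-episode update above.
        G, out = 0.0, []
        for r in rewards:
            G = r + GAMMA * G
            out.append(G)
        return out

    print([round(g, 4) for g in running_discounted([1.0, 0.0, 2.0])])  # [1.0, 0.9, 2.81]
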
@@ -252,7 +253,8 @@ def learn(self, writer):
                     len_ep.append(len(trajectory))
                     nb_episodes += 1
                     writer.add_scalar("episode_reward", ep_reward, nb_episodes)
-                    writer.add_scalar("normalized return", (ep_reward / ideal_reward), nb_episodes)
+                    writer.add_scalar("episode_mesh_reward", ep_mesh_reward, nb_episodes)
+                    writer.add_scalar("normalized return", (ep_mesh_reward / ideal_reward), nb_episodes)
                     writer.add_scalar("len_episodes", len(trajectory), nb_episodes)
 
             self.train(dataset)
@@ -263,4 +265,5 @@ def learn(self, writer):
         except NaNExceptionCritic:
             print("NaN Exception on Critic Network")
             return None, None, None, None
-        return self.actor, rewards, wins, len_ep
+
+        return self.actor, rewards, wins, len_ep, info["observation_count"]