 import pathlib
 import random
 import time
+from collections import deque
 from dataclasses import dataclass

 import gymnasium as gym
@@ -181,6 +182,10 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
 next_obs = torch.Tensor(next_obs).to(device)
 next_done = torch.zeros(args.num_envs).to(device)

+# episode reward stats, added here as Godot RL does not return this information in info (yet)
+episode_returns = deque(maxlen=20)
+accum_rewards = np.zeros(args.num_envs)
+
 for iteration in range(1, args.num_iterations + 1):
     # Annealing the rate if instructed to do so.
     if args.anneal_lr:
@@ -212,12 +217,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         rewards[step] = torch.tensor(reward).to(device).view(-1)
         next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

-        if "final_info" in infos:
-            for info in infos["final_info"]:
-                if info and "episode" in info:
-                    print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
-                    writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
-                    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
+        accum_rewards += np.array(reward)
+
+        for i, d in enumerate(next_done):
+            if d:
+                episode_returns.append(accum_rewards[i])
+                accum_rewards[i] = 0

     # Compute Q(lambda) targets
     with torch.no_grad():
@@ -261,8 +266,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

     writer.add_scalar("losses/td_loss", loss, global_step)
     writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
-    print(f"SPS: {int(global_step / (time.time() - start_time))}, Epsilon: {epsilon}")
+    print(f"SPS: {int(global_step / (time.time() - start_time))}, Epsilon (rand action prob): {epsilon}")
+    if len(episode_returns) > 0:
+        print(
+            "Returns:",
+            np.mean(np.array(episode_returns)),
+        )
     writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+    writer.add_scalar("charts/episodic_return", np.mean(np.array(episode_returns)), global_step)

 envs.close()
 writer.close()
@@ -298,4 +309,4 @@ def forward(self, onnx_obs, state_ins):
298309 "output" : {0 : "batch_size" },
299310 "state_outs" : {0 : "batch_size" },
300311 },
301- )
312+ )
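
For reference, the per-env return tracking introduced in this patch can be sketched in isolation. The snippet below is a minimal, self-contained illustration of the same pattern (accumulate rewards per environment, then bank the total into a bounded deque whenever that environment reports done); the `num_envs` value, the `track_returns` helper, and the toy random rollout are assumptions made for the example, not part of the actual script.

```python
from collections import deque

import numpy as np

num_envs = 4  # assumed number of parallel environments (illustrative)
episode_returns = deque(maxlen=20)  # returns of the last 20 finished episodes
accum_rewards = np.zeros(num_envs)  # running return of the in-progress episode, per env


def track_returns(reward, done):
    """Add this step's rewards and bank the totals of any finished episodes."""
    global accum_rewards
    accum_rewards += np.asarray(reward, dtype=np.float64)
    for i, d in enumerate(done):
        if d:
            episode_returns.append(accum_rewards[i])
            accum_rewards[i] = 0.0


# Toy rollout standing in for envs.step(): random rewards, ~5% chance an env finishes each step
rng = np.random.default_rng(0)
for _ in range(200):
    reward = rng.normal(size=num_envs)
    done = rng.random(num_envs) < 0.05
    track_returns(reward, done)

if len(episode_returns) > 0:
    print(f"mean return over last {len(episode_returns)} episodes: {np.mean(episode_returns):.3f}")
```

Guarding the mean behind the length check mirrors the `if len(episode_returns) > 0:` print added in the patch and avoids averaging an empty window early in training.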