 import numpy as np
 from collections import deque
 import logging
-
+from rich import print
 logger = logging.getLogger("ray")
 @ray.remote
 class EpisodeStatistics:
@@ -47,18 +47,30 @@ def log_statistics(self, step: int, record_next_episode: bool):
         num_test_tasks = 0
         sum_discounted_reward = 0
         sum_episode_length = 0
+        num_valid_episode_length = 0  # Track how many tasks reported a valid (non-NaN) episode length
+
+        # [Change log] Moved the step logging from the per-task loop to the beginning of this method (zhancun)
+        wandb_logger.log({
+            "episode_statistics/step": step,
+        })
 
         for task in self.sum_rewards_metrics.keys():
             mean_sum_reward = self.sum_rewards_metrics[task].compute()
             mean_discounted_reward = self.discounted_rewards_metrics[task].compute()
             mean_episode_length = self.episode_lengths_metrics[task].compute()
 
+            # Log per-task metrics, skipping tasks that recorded no episodes this window (NaN)
+            if not np.isnan(mean_sum_reward):
+                wandb_logger.log({
+                    f"episode_statistics/{task}/sum_reward": mean_sum_reward,
+                    f"episode_statistics/{task}/discounted_reward": mean_discounted_reward,
+                    f"episode_statistics/{task}/episode_length": mean_episode_length,
+                })
+                print(f"Task {task} - Sum Reward: {mean_sum_reward}, Discounted Reward: {mean_discounted_reward}, Episode Length: {mean_episode_length}")
+
             self.sum_rewards_metrics[task].reset()
             self.discounted_rewards_metrics[task].reset()
             self.episode_lengths_metrics[task].reset()
-            wandb_logger.log({
-                "episode_statistics/step": step,
-            })
 
             if not np.isnan(mean_sum_reward) and "4train" in task:
                 sum_train_reward += mean_sum_reward
@@ -67,22 +79,26 @@ def log_statistics(self, step: int, record_next_episode: bool):
             if not np.isnan(mean_sum_reward) and "4test" in task:
                 sum_test_reward += mean_sum_reward
                 num_test_tasks += 1
-            sum_episode_length += mean_episode_length
+
+            # Only add the episode length if it's not NaN
+            if not np.isnan(mean_episode_length):
+                sum_episode_length += mean_episode_length
+                num_valid_episode_length += 1
 
         self.episode_info = {
             "steps": step,
             "episode_count": self.acc_episode_count,
             "mean_sum_reward": sum_train_reward / num_train_tasks if num_train_tasks > 0 else 0,
             "mean_discounted_reward": sum_discounted_reward / num_train_tasks if num_train_tasks > 0 else 0,
-            "mean_episode_length": sum_episode_length / (num_train_tasks + num_test_tasks) if num_train_tasks + num_test_tasks > 0 else 0
+            "mean_episode_length": sum_episode_length / num_valid_episode_length if num_valid_episode_length > 0 else 0
         }
         wandb_logger.log({
             "episode_statistics/steps": step,
             "episode_statistics/episode_count": self.acc_episode_count,
             "episode_statistics/mean_sum_reward": sum_train_reward / num_train_tasks if num_train_tasks > 0 else 0,
             "episode_statistics/mean_test_sum_reward": sum_test_reward / num_test_tasks if num_test_tasks > 0 else 0,
             "episode_statistics/mean_discounted_reward": sum_discounted_reward / num_train_tasks if num_train_tasks > 0 else 0,
-            "episode_statistics/mean_episode_length": sum_episode_length / (num_train_tasks + num_test_tasks) if num_train_tasks + num_test_tasks > 0 else 0
+            "episode_statistics/mean_episode_length": sum_episode_length / num_valid_episode_length if num_valid_episode_length > 0 else 0
         })
 
         self.acc_episode_count = 0
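
The heart of this change is the NaN handling: a task that finishes no episodes during the logging window is assumed to return NaN from its metric's `compute()`, so adding it into `sum_episode_length` would turn the whole sum into NaN, and the old denominator `num_train_tasks + num_test_tasks` counted tasks by reward validity rather than by episode-length validity. Below is a minimal, self-contained sketch of the guarded averaging; `mean_episode_length` is a hypothetical helper standing in for the accumulation loop above, not part of the actual code.

```python
import numpy as np

def mean_episode_length(per_task_lengths):
    """Average only the tasks that produced a real value.

    A task with no finished episodes in the window is represented by NaN,
    mirroring what the metric's compute() is assumed to return in the diff.
    """
    total = 0.0
    valid = 0
    for length in per_task_lengths:
        if not np.isnan(length):
            total += length
            valid += 1
    return total / valid if valid > 0 else 0

# Two tasks finished episodes, one did not; the NaN entry is skipped
# instead of poisoning the sum or inflating the denominator.
print(mean_episode_length([120.0, float("nan"), 80.0]))  # 100.0
print(mean_episode_length([float("nan")]))               # 0
```

Dividing by the count of tasks that actually reported a length, rather than by the number of train and test tasks, keeps idle tasks from either dragging the mean to NaN or silently shrinking it.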