@@ -43,6 +43,8 @@ def setUp(self):
         self.config.checkpoint_root_dir = get_checkpoint_path()
         self.config.synchronizer.sync_interval = 2
         self.config.explorer.eval_interval = 4
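+        # detailed eval stats stay off for the base cases; TestExplorerEvalDetailedStats enables them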
+        self.config.monitor.detailed_stats = False


 class TestExplorerCountdownEval(BaseExplorerCase):
@@ -70,21 +71,51 @@ def test_explorer(self):
         self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
         for eval_taskset, k_list in zip(eval_tasksets, [[1], [2, 4, 6], [2, 4, 8, 10]]):
             metric_name = "score" if eval_taskset.name == "countdown" else "accuracy"
-            for eval_stats in ["mean", "std"]:
-                k = k_list[-1]
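+            # expect mean/std over all repeats, plus best@k/worst@k for each k > 1 in k_list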
+            repeat_times = k_list[-1]
+            expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+            for k in k_list:
+                if k == 1:
+                    continue
+                expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+            # detailed_stats is off here, so each stat logs only the column mean (no /mean, /std, /max, /min suffix)
+            for stat_suffix in expected_stat_suffixes:
                 self.assertIn(
-                    f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}",
+                    f"eval/{eval_taskset.name}/{metric_name}/{stat_suffix}",
+                    eval_metrics,
+                )
+
+
+class TestExplorerEvalDetailedStats(BaseExplorerCase):
+    def test_explorer(self):
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+        self.config.monitor.detailed_stats = True
+        eval_taskset = get_unittest_dataset_config("eval_short")
+        eval_taskset.repeat_times = 6
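+        # six repeats per eval task, so best@k/worst@k are well-defined for each k in [2, 4, 6]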
+        self.config.buffer.explorer_input.eval_tasksets = [eval_taskset]
+        self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+        self.config.check_and_update()
+        explore(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        eval_metrics = parser.metric_list("eval")
+        self.assertTrue(len(eval_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
+        self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
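+        # with detailed_stats enabled, each stat column should also report mean/std/max/min aggregates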
+        metric_name, repeat_times, k_list = "accuracy", 6, [2, 4, 6]
+        expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+        for k in k_list:  # k_list does not include 1
+            expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+        # test detailed stats
+        for stat_suffix in expected_stat_suffixes:
+            for stats in ["mean", "std", "max", "min"]:
+                self.assertIn(
+                    f"eval/{eval_taskset.name}/{metric_name}/{stat_suffix}/{stats}",
                     eval_metrics,
                 )
-            for eval_stats in ["best", "worst"]:
-                for k in k_list:
-                    if k == 1:
-                        continue
-                    for stats in ["mean", "std"]:
-                        self.assertIn(
-                            f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}/{stats}",
-                            eval_metrics,
-                        )


 class TestExplorerGSM8KRULERNoEval(BaseExplorerCase):