@@ -84,12 +84,11 @@ def test_trainer(self):
8484 self .config .buffer .explorer_input .taskset .task_selector = TaskSelectorConfig (
8585 selector_type = "shuffle" , seed = 42
8686 )
87- self .config .buffer .explorer_input .eval_tasksets .append (
88- get_unittest_dataset_config ("countdown" , "test" )
89- )
90- self .config .buffer .explorer_input .eval_tasksets .append (
91- get_unittest_dataset_config ("copy_countdown" , "test" )
92- )
87+ eval_tasksets = self .config .buffer .explorer_input .eval_tasksets
88+ eval_tasksets .append (get_unittest_dataset_config ("countdown" , "test" ))
89+ eval_tasksets .append (get_unittest_dataset_config ("copy_countdown" , "test" ))
90+ eval_tasksets [0 ].repeat_times = 4
91+ eval_tasksets [1 ].repeat_times = 4
9392 self .config .trainer .save_interval = 4
9493 self .config .check_and_update ()
9594 _trainer_config = self .config .trainer .trainer_config
@@ -148,14 +147,15 @@ def test_trainer(self):
148147 bench (self .config )
149148 parser = TensorBoardParser (os .path .join (self .config .monitor .cache_dir , "tensorboard" ))
150149 for prefix in ["eval" , "bench" ]:
151- countdown_metrics = parser .metric_list (f"{ prefix } /countdown" )
152- copy_countdown_metrics = parser .metric_list (f"{ prefix } /copy_countdown" )
153- self .assertTrue (len (countdown_metrics ) > 0 )
154- self .assertTrue (len (copy_countdown_metrics ) > 0 )
155- countdown_metric_steps = parser .metric_steps (countdown_metrics [0 ])
156- countdown_copy_metric_steps = parser .metric_steps (copy_countdown_metrics [0 ])
157- self .assertEqual ([0 , 4 , 8 ], countdown_metric_steps )
158- self .assertEqual ([0 , 4 , 8 ], countdown_copy_metric_steps )
150+ for taskset_name in ["countdown" , "copy_countdown" ]:
151+ metrics = parser .metric_list (f"{ prefix } /{ taskset_name } " )
152+ self .assertTrue (len (metrics ) > 0 )
153+ for eval_stats in ["mean" , "best" , "worst" ]:
154+ for k in [2 , 4 ]:
155+ for stats in ["mean" , "std" ]:
156+ metric_name = f"{ prefix } /{ taskset_name } /score/{ eval_stats } @{ k } /{ stats } "
157+ metric_steps = parser .metric_steps (metric_name )
158+ self .assertEqual (metric_steps , [0 , 4 , 8 ])
159159
160160 def tearDown (self ):
161161 # remove dir only when the test passed
@@ -969,6 +969,7 @@ def test_trainer(self):
969969 self .config .buffer .explorer_input .eval_tasksets .append (
970970 get_unittest_dataset_config ("gsm8k" , "test" )
971971 )
972+ self .config .buffer .explorer_input .eval_tasksets [0 ].repeat_times = 8
972973 self .config .model .model_path = get_model_path ()
973974 self .config .algorithm .algorithm_type = "grpo"
974975 self .config .algorithm .advantage_fn = "grpo"
@@ -1019,8 +1020,12 @@ def test_trainer(self):
10191020 for prefix in ["eval" , "bench" ]:
10201021 gsm8k_metrics = parser .metric_list (f"{ prefix } /gsm8k" )
10211022 self .assertTrue (len (gsm8k_metrics ) > 0 )
1022- gsm8k_metric_steps = parser .metric_steps (gsm8k_metrics [0 ])
1023- self .assertEqual ([0 , 2 ], gsm8k_metric_steps )
1023+ for eval_stats in ["mean" , "best" , "worst" ]:
1024+ for k in [2 , 4 , 8 ]:
1025+ for stats in ["mean" , "std" ]:
1026+ metric_name = f"{ prefix } /gsm8k/accuracy/{ eval_stats } @{ k } /{ stats } "
1027+ metric_steps = parser .metric_steps (metric_name )
1028+ self .assertEqual (metric_steps , [0 , 2 ])
10241029
10251030 def tearDown (self ):
10261031 shutil .rmtree (self .config .checkpoint_job_dir )
0 commit comments