@@ -90,6 +90,8 @@ def test_trainer(self):
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("copy_countdown", "test")
         )
+        self.config.buffer.explorer_input.eval_tasksets[0].eval_at_k = [1, 2]
+        self.config.buffer.explorer_input.eval_tasksets[1].eval_at_k = [3, 4]
         self.config.trainer.save_interval = 4
         self.config.check_and_update()
         _trainer_config = self.config.trainer.trainer_config
@@ -147,15 +149,17 @@ def test_trainer(self):
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
         for prefix in ["eval", "bench"]:
-            countdown_metrics = parser.metric_list(f"{prefix}/countdown")
-            copy_countdown_metrics = parser.metric_list(f"{prefix}/copy_countdown")
-            self.assertTrue(len(countdown_metrics) > 0)
-            self.assertTrue(len(copy_countdown_metrics) > 0)
-            countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
-            countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
-            self.assertEqual([0, 4, 8], countdown_metric_steps)
-            self.assertEqual([0, 4, 8], countdown_copy_metric_steps)
+            for eval_taskset, taskset_name in zip(eval_tasksets, ["countdown", "copy_countdown"]):
+                metrics = parser.metric_list(f"{prefix}/{taskset_name}")
+                self.assertTrue(len(metrics) > 0)
+                for eval_stats in ["mean", "best", "worst"]:
+                    for k in eval_taskset.eval_at_k:
+                        for stats in ["mean", "std"]:
+                            metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
+                            metric_steps = parser.metric_steps(metric_name)
+                            self.assertEqual(metric_steps, [0, 4, 8])

     def tearDown(self):
         # remove dir only when the test passed
@@ -969,6 +973,7 @@ def test_trainer(self):
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("gsm8k", "test")
         )
+        self.config.buffer.explorer_input.eval_tasksets[0].eval_at_k = [1, 2, 4]
         self.config.model.model_path = get_model_path()
         self.config.algorithm.algorithm_type = "grpo"
         self.config.algorithm.advantage_fn = "grpo"
@@ -1016,11 +1021,16 @@ def test_trainer(self):
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        eval_taskset = self.config.buffer.explorer_input.eval_tasksets[0]
         for prefix in ["eval", "bench"]:
             gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
             self.assertTrue(len(gsm8k_metrics) > 0)
-            gsm8k_metric_steps = parser.metric_steps(gsm8k_metrics[0])
-            self.assertEqual([0, 2], gsm8k_metric_steps)
+            for eval_stats in ["mean", "best", "worst"]:
+                for k in eval_taskset.eval_at_k:
+                    for stats in ["mean", "std"]:
+                        metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
+                        metric_steps = parser.metric_steps(metric_name)
+                        self.assertEqual(metric_steps, [0, 2])

     def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir)
0 commit comments