@@ -84,14 +84,11 @@ def test_trainer(self):
8484 self .config .buffer .explorer_input .taskset .task_selector = TaskSelectorConfig (
8585 selector_type = "shuffle" , seed = 42
8686 )
87- self .config .buffer .explorer_input .eval_tasksets .append (
88- get_unittest_dataset_config ("countdown" , "test" )
89- )
90- self .config .buffer .explorer_input .eval_tasksets .append (
91- get_unittest_dataset_config ("copy_countdown" , "test" )
92- )
93- self .config .buffer .explorer_input .eval_tasksets [0 ].eval_at_k = [1 , 2 ]
94- self .config .buffer .explorer_input .eval_tasksets [1 ].eval_at_k = [3 , 4 ]
87+ eval_tasksets = self .config .buffer .explorer_input .eval_tasksets
88+ eval_tasksets .append (get_unittest_dataset_config ("countdown" , "test" ))
89+ eval_tasksets .append (get_unittest_dataset_config ("copy_countdown" , "test" ))
90+ eval_tasksets [0 ].repeat_times = 4
91+ eval_tasksets [1 ].repeat_times = 4
9592 self .config .trainer .save_interval = 4
9693 self .config .check_and_update ()
9794 _trainer_config = self .config .trainer .trainer_config
@@ -149,13 +146,12 @@ def test_trainer(self):
149146 self .config .check_and_update ()
150147 bench (self .config )
151148 parser = TensorBoardParser (os .path .join (self .config .monitor .cache_dir , "tensorboard" ))
152- eval_tasksets = self .config .buffer .explorer_input .eval_tasksets
153149 for prefix in ["eval" , "bench" ]:
154- for eval_taskset , taskset_name in zip ( eval_tasksets , ["countdown" , "copy_countdown" ]) :
150+ for taskset_name in ["countdown" , "copy_countdown" ]:
155151 metrics = parser .metric_list (f"{ prefix } /{ taskset_name } " )
156152 self .assertTrue (len (metrics ) > 0 )
157153 for eval_stats in ["mean" , "best" , "worst" ]:
158- for k in eval_taskset . eval_at_k :
154+ for k in [ 2 , 4 ] :
159155 for stats in ["mean" , "std" ]:
160156 metric_name = f"{ prefix } /{ taskset_name } /score/{ eval_stats } @{ k } /{ stats } "
161157 metric_steps = parser .metric_steps (metric_name )
@@ -973,7 +969,7 @@ def test_trainer(self):
973969 self .config .buffer .explorer_input .eval_tasksets .append (
974970 get_unittest_dataset_config ("gsm8k" , "test" )
975971 )
976- self .config .buffer .explorer_input .eval_tasksets [0 ].eval_at_k = [ 1 , 2 , 4 ]
972+ self .config .buffer .explorer_input .eval_tasksets [0 ].repeat_times = 8
977973 self .config .model .model_path = get_model_path ()
978974 self .config .algorithm .algorithm_type = "grpo"
979975 self .config .algorithm .advantage_fn = "grpo"
@@ -1021,12 +1017,11 @@ def test_trainer(self):
10211017 self .config .check_and_update ()
10221018 bench (self .config )
10231019 parser = TensorBoardParser (os .path .join (self .config .monitor .cache_dir , "tensorboard" ))
1024- eval_taskset = self .config .buffer .explorer_input .eval_tasksets [0 ]
10251020 for prefix in ["eval" , "bench" ]:
10261021 gsm8k_metrics = parser .metric_list (f"{ prefix } /gsm8k" )
10271022 self .assertTrue (len (gsm8k_metrics ) > 0 )
10281023 for eval_stats in ["mean" , "best" , "worst" ]:
1029- for k in eval_taskset . eval_at_k :
1024+ for k in [ 2 , 4 , 8 ] :
10301025 for stats in ["mean" , "std" ]:
10311026 metric_name = f"{ prefix } /gsm8k/accuracy/{ eval_stats } @{ k } /{ stats } "
10321027 metric_steps = parser .metric_steps (metric_name )
0 commit comments