@@ -67,6 +67,10 @@ def test_trainer(self):
6767 actor_metrics = parser .metric_list ("actor" )
6868 self .assertTrue (len (actor_metrics ) > 0 )
6969 self .assertEqual (parser .metric_max_step (actor_metrics [0 ]), 8 )
70+ actor_kl_metrics = parser .metric_list ("actor/kl" )
71+ self .assertTrue (len (actor_kl_metrics ) > 0 )
72+ critic_kl_metrics = parser .metric_list ("critic/kl" )
73+ self .assertTrue (len (critic_kl_metrics ) > 0 )
7074 response_metrics = parser .metric_list ("response_length" )
7175 self .assertTrue (len (response_metrics ) > 0 )
7276 self .assertEqual (parser .metric_max_step (response_metrics [0 ]), 8 )
@@ -86,7 +90,7 @@ def test_trainer(self):
8690 )
8791 self .assertTrue (os .path .exists (checkpoint_step_4 ))
8892 self .assertTrue (os .path .exists (checkpoint_step_8 ))
89-
93+ # TODO: Reinit will fail when using v1 engine, find a way to fix it
9094 ray .init (ignore_reinit_error = True )
9195 # test bench mode
9296 self .config .mode = "bench"
@@ -118,7 +122,7 @@ def test_trainer(self):
118122 self .config .algorithm .algorithm_type = AlgorithmType .GRPO
119123 self .config .algorithm .repeat_times = 4
120124 # self.config.algorithm.repeat_times = 8 # TODO: used for real testing
121- self .config .algorithm .advantage_fn_type = "grpo_adv_fn "
125+ self .config .algorithm .advantage_fn = "grpo "
122126 self .config .algorithm .advantage_fn_args = {}
123127 # self.config.buffer.batch_size = 96 # TODO: used for real testing
124128 self .config .buffer .explorer_input .taskset = get_unittest_dataset_config ("gsm8k" )
@@ -143,8 +147,6 @@ def test_trainer(self):
143147 # self.assertTrue(0.4 < rewards[1] < 0.55)
144148 # self.assertTrue(0.6 < rewards[2] < 0.7)
145149 # self.assertTrue(0.6 < rewards[3] < 0.7)
146- ray .shutdown (_exiting_interpreter = True )
147- # check checkpoint
148150
149151 def tearDown (self ):
150152 # remove dir only when the test passed
@@ -157,7 +159,7 @@ def test_trainer(self):
157159 # test both mode
158160 self .config .algorithm .algorithm_type = AlgorithmType .GRPO
159161 self .config .algorithm .repeat_times = 4
160- self .config .algorithm .advantage_fn_type = "grpo_adv_fn "
162+ self .config .algorithm .advantage_fn = "grpo "
161163 self .config .algorithm .advantage_fn_args = {}
162164 self .config .buffer .explorer_input .taskset = get_unittest_dataset_config ("gsm8k" )
163165 self .config .buffer .trainer_input .sft_warmup_steps = 2
@@ -180,8 +182,6 @@ def test_trainer(self):
180182 response_metrics = parser .metric_list ("response_length" )
181183 self .assertTrue (len (response_metrics ) > 0 )
182184 self .assertEqual (parser .metric_max_step (response_metrics [0 ]), 4 )
183- ray .shutdown (_exiting_interpreter = True )
184- # check checkpoint
185185
186186 def tearDown (self ):
187187 # remove dir only when the test passed
@@ -207,8 +207,6 @@ def test_trainer(self):
207207 actor_metrics = parser .metric_list ("actor" )
208208 self .assertTrue (len (actor_metrics ) > 0 )
209209 self .assertEqual (parser .metric_max_step (actor_metrics [0 ]), 4 )
210- ray .shutdown (_exiting_interpreter = True )
211- # check checkpoint
212210
213211 def tearDown (self ):
214212 # remove dir only when the test passed
0 commit comments