@@ -27,7 +27,7 @@ def setUp(self):
2727 self .config .model .model_path = get_model_path ()
2828 self .config .explorer .rollout_model .engine_type = "vllm_async"
2929 self .config .algorithm .repeat_times = 3
30- self .config .explorer .rollout_model .use_v1 = True
30+ self .config .explorer .rollout_model .use_v1 = False
3131 self .config .project = "Trainer-unittest"
3232 self .config .name = f"trainer-{ datetime .now ().strftime ('%Y%m%d%H%M%S' )} "
3333 self .config .monitor .monitor_type = MonitorType .TENSORBOARD
@@ -67,6 +67,10 @@ def test_trainer(self):
6767 actor_metrics = parser .metric_list ("actor" )
6868 self .assertTrue (len (actor_metrics ) > 0 )
6969 self .assertEqual (parser .metric_max_step (actor_metrics [0 ]), 8 )
70+ actor_kl_metrics = parser .metric_list ("actor/kl" )
71+ self .assertTrue (len (actor_kl_metrics ) > 0 )
72+ critic_kl_metrics = parser .metric_list ("critic/kl" )
73+ self .assertTrue (len (critic_kl_metrics ) > 0 )
7074 response_metrics = parser .metric_list ("response_length" )
7175 self .assertTrue (len (response_metrics ) > 0 )
7276 self .assertEqual (parser .metric_max_step (response_metrics [0 ]), 8 )
@@ -86,7 +90,7 @@ def test_trainer(self):
8690 )
8791 self .assertTrue (os .path .exists (checkpoint_step_4 ))
8892 self .assertTrue (os .path .exists (checkpoint_step_8 ))
89-
93+ # TODO: Reinit will fail when using v1 engine, find a way to fix it
9094 ray .init (ignore_reinit_error = True )
9195 # test bench mode
9296 self .config .mode = "bench"
0 commit comments