Skip to content

Commit 526d0aa

Browse files
committed
fix trainer test
1 parent 4c773bb commit 526d0aa

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/trainer/trainer_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def setUp(self):
2727
self.config.model.model_path = get_model_path()
2828
self.config.explorer.rollout_model.engine_type = "vllm_async"
2929
self.config.algorithm.repeat_times = 3
30-
self.config.explorer.rollout_model.use_v1 = True
30+
self.config.explorer.rollout_model.use_v1 = False
3131
self.config.project = "Trainer-unittest"
3232
self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
3333
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
@@ -67,6 +67,10 @@ def test_trainer(self):
6767
actor_metrics = parser.metric_list("actor")
6868
self.assertTrue(len(actor_metrics) > 0)
6969
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
70+
actor_kl_metrics = parser.metric_list("actor/kl")
71+
self.assertTrue(len(actor_kl_metrics) > 0)
72+
critic_kl_metrics = parser.metric_list("critic/kl")
73+
self.assertTrue(len(critic_kl_metrics) > 0)
7074
response_metrics = parser.metric_list("response_length")
7175
self.assertTrue(len(response_metrics) > 0)
7276
self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
@@ -86,7 +90,7 @@ def test_trainer(self):
8690
)
8791
self.assertTrue(os.path.exists(checkpoint_step_4))
8892
self.assertTrue(os.path.exists(checkpoint_step_8))
89-
93+
# TODO: Reinit will fail when using v1 engine, find a way to fix it
9094
ray.init(ignore_reinit_error=True)
9195
# test bench mode
9296
self.config.mode = "bench"

0 commit comments

Comments
 (0)