Skip to content

Commit c9859d5

Browse files
committed
fix test
1 parent 8e644ef commit c9859d5

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

tests/trainer/trainer_test.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,7 @@ def test_trainer(self):
10501050
self.config.model.max_model_len = 20
10511051
self.config.model.max_prompt_tokens = 5
10521052
self.config.model.max_response_tokens = 15
1053-
# self.config.model.enable_prompt_truncation = True
1053+
self.config.model.enable_prompt_truncation = True
10541054
self.config.algorithm.algorithm_type = "grpo"
10551055
self.config.algorithm.advantage_fn = "grpo"
10561056
self.config.algorithm.kl_loss_fn = "none"
@@ -1068,10 +1068,18 @@ def test_trainer(self):
10681068
actor_metrics = parser.metric_list("actor")
10691069
self.assertTrue(len(actor_metrics) > 0)
10701070
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)
1071-
max_prompt = parser.metric_values("prompt_length/max")
1072-
self.assertEqual(max(max_prompt), 5)
1073-
min_prompt = parser.metric_values("prompt_length/min")
1074-
self.assertEqual(min(min_prompt), 5)
1071+
max_prompt_length = parser.metric_values("prompt_length/max")
1072+
self.assertEqual(max(max_prompt_length), 5)
1073+
min_prompt_length = parser.metric_values("prompt_length/min")
1074+
self.assertEqual(min(min_prompt_length), 5)
1075+
max_response_length = parser.metric_values("response_length/max")
1076+
self.assertEqual(max(max_response_length), 1)
1077+
min_response_length = parser.metric_values("response_length/min")
1078+
self.assertEqual(min(min_response_length), 1)
1079+
final_loss = parser.metric_values("actor/final_loss")
1080+
self.assertEqual(final_loss[0], 0.0)
1081+
grad_norm = parser.metric_values("actor/grad_norm")
1082+
self.assertEqual(grad_norm[0], 0.0)
10751083

10761084
def tearDown(self):
10771085
# remove dir only when the test passed

0 commit comments

Comments
 (0)