@@ -1050,7 +1050,7 @@ def test_trainer(self):
10501050 self .config .model .max_model_len = 20
10511051 self .config .model .max_prompt_tokens = 5
10521052 self .config .model .max_response_tokens = 15
1053- # self.config.model.enable_prompt_truncation = True
1053+ self .config .model .enable_prompt_truncation = True
10541054 self .config .algorithm .algorithm_type = "grpo"
10551055 self .config .algorithm .advantage_fn = "grpo"
10561056 self .config .algorithm .kl_loss_fn = "none"
@@ -1068,10 +1068,18 @@ def test_trainer(self):
10681068 actor_metrics = parser .metric_list ("actor" )
10691069 self .assertTrue (len (actor_metrics ) > 0 )
10701070 self .assertEqual (parser .metric_max_step (actor_metrics [0 ]), 2 )
1071- max_prompt = parser .metric_values ("prompt_length/max" )
1072- self .assertEqual (max (max_prompt ), 5 )
1073- min_prompt = parser .metric_values ("prompt_length/min" )
1074- self .assertEqual (min (min_prompt ), 5 )
1071+ max_prompt_length = parser .metric_values ("prompt_length/max" )
1072+ self .assertEqual (max (max_prompt_length ), 5 )
1073+ min_prompt_length = parser .metric_values ("prompt_length/min" )
1074+ self .assertEqual (min (min_prompt_length ), 5 )
1075+ max_response_length = parser .metric_values ("response_length/max" )
1076+ self .assertEqual (max (max_response_length ), 1 )
1077+ min_response_length = parser .metric_values ("response_length/min" )
1078+ self .assertEqual (min (min_response_length ), 1 )
1079+ final_loss = parser .metric_values ("actor/final_loss" )
1080+ self .assertEqual (final_loss [0 ], 0.0 )
1081+ grad_norm = parser .metric_values ("actor/grad_norm" )
1082+ self .assertEqual (grad_norm [0 ], 0.0 )
10751083
10761084 def tearDown (self ):
10771085 # remove dir only when the test passed
0 commit comments