@@ -263,12 +263,24 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
 
         self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
 
-        # Calculate response length statistics
+        # Calculate response length statistics (action tokens)
         response_lengths = micro_batch['action_mask'].sum(dim=1).float()
         avg_response_length = response_lengths.mean().item()
         max_response_length = response_lengths.max().item()
         min_response_length = response_lengths.min().item()
 
+        # Calculate prompt length statistics (total - action tokens)
+        total_lengths = micro_batch['attention_mask'].sum(dim=1).float()
+        prompt_lengths = total_lengths - response_lengths
+        avg_prompt_length = prompt_lengths.mean().item()
+        max_prompt_length = prompt_lengths.max().item()
+        min_prompt_length = prompt_lengths.min().item()
+
+        # Calculate total sequence length statistics
+        avg_total_length = total_lengths.mean().item()
+        max_total_length = total_lengths.max().item()
+        min_total_length = total_lengths.min().item()
+
         # Calculate log_probs statistics
         valid_log_probs = action_log_probs[micro_batch['action_mask'] > 0]
         avg_log_prob = valid_log_probs.mean().item() if valid_log_probs.numel() > 0 else 0.0
@@ -288,10 +300,18 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
             # "approx_kl": approx_kl.detach().float().mean().item(),
             "cur_old_kl": approx_kl.detach().float().mean().item(),
             "iter": self.train_iter,
-            # Response length statistics
+            # Response length statistics (action tokens)
             "response_length_avg": avg_response_length,
             "response_length_max": max_response_length,
             "response_length_min": min_response_length,
+            # Prompt length statistics (context tokens)
+            "prompt_length_avg": avg_prompt_length,
+            "prompt_length_max": max_prompt_length,
+            "prompt_length_min": min_prompt_length,
+            # Total sequence length statistics
+            "total_length_avg": avg_total_length,
+            "total_length_max": max_total_length,
+            "total_length_min": min_total_length,
             # Log prob and ratio statistics
             "log_prob_avg": avg_log_prob,
             "ratio_avg": avg_ratio,