@@ -297,30 +297,49 @@ def _get_encoded_batch(rollout_batch):
             encoded_list = [template.encode(data, return_length=True) for data in rollout_batch]
             encoded_batch = to_device(
                 template.data_collator(encoded_list, padding_to=get_padding_to(args)), self.device)
-            if 'cu_seq_lens_q' in encoded_batch:
-                cu_seq_lens_q = encoded_batch['cu_seq_lens_q']
-            else:
-                cu_seq_lens_q = get_packed_seq_params(encoded_batch['position_ids'])['cu_seq_lens_q']
-            seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
 
             labels = encoded_batch['labels']
             batch_size = len(rollout_batch)
-            max_seq_len = seq_lengths.max().item()
-            assert self.template.padding_free
 
             truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch],
                                           dtype=torch.bool,
                                           device=self.device)
 
-            # completion_mask in rmpad format [1, total_tokens]
-            completion_mask_rmpad = (labels != -100).float()
-            completion_mask, _ = pad_logps_back_to_batch(
-                logps_rmpad=completion_mask_rmpad,
-                logits_to_keep=max_seq_len,
-                batch_size=batch_size,
-                seq_lengths=seq_lengths,
-                pad_value=0.0)
-            completion_mask = completion_mask.bool()
+            if self.template.padding_free:
+                # In padding_free mode, labels shape is [1, total_seq_len] (rmpad format).
+                # Calculate seq_lengths from cu_seq_lens or position_ids.
+                if 'cu_seq_lens_q' in encoded_batch:
+                    cu_seq_lens_q = encoded_batch['cu_seq_lens_q']
+                else:
+                    cu_seq_lens_q = get_packed_seq_params(encoded_batch['position_ids'])['cu_seq_lens_q']
+                seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+                max_seq_len = seq_lengths.max().item()
+
+                # completion_mask in rmpad format [1, total_tokens]
+                completion_mask_rmpad = (labels != -100).float()
+                completion_mask, _ = pad_logps_back_to_batch(
+                    logps_rmpad=completion_mask_rmpad,
+                    logits_to_keep=max_seq_len,
+                    batch_size=batch_size,
+                    seq_lengths=seq_lengths,
+                    pad_value=0.0)
+                completion_mask = completion_mask.bool()
+            else:
+                # In non-padding_free mode, labels shape is [batch_size, seq_len] (batch format).
+                # Calculate seq_lengths from the attention_mask.
+                attention_mask = encoded_batch.get('attention_mask')
+                if attention_mask is not None:
+                    # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
+                    if attention_mask.dim() == 4:
+                        attention_mask = attention_mask[:, 0, 0, :]
+                    seq_lengths = attention_mask.sum(dim=-1).to(torch.int64)
+                else:
+                    # Fallback: assume the full sequence length for each sample.
+                    seq_lengths = torch.full((batch_size, ), labels.shape[-1], dtype=torch.int64, device=self.device)
+                max_seq_len = labels.shape[-1]
+
+                # completion_mask is already [batch_size, seq_len] in non-padding_free mode.
+                completion_mask = (labels != -100)
 
             encoded_batch.update({
                 'completion_mask': completion_mask,  # [batch_size, max_seq_len]
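For readers unfamiliar with the rmpad (padding-free) layout used above: `cu_seq_lens_q` holds cumulative sequence boundaries, so adjacent differences give per-sequence lengths, and `pad_logps_back_to_batch` scatters the flat `[1, total_tokens]` row back into `[batch_size, max_seq_len]`. A minimal, self-contained sketch of that round trip with toy values (not the project's actual helper):

```python
import torch

# Toy packed batch: three sequences of lengths 3, 5, 2 flattened into one row.
cu_seq_lens_q = torch.tensor([0, 3, 8, 10])
values_rmpad = torch.arange(10, dtype=torch.float32).unsqueeze(0)  # [1, total_tokens]

seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]  # tensor([3, 5, 2])
batch_size, max_seq_len = seq_lengths.numel(), int(seq_lengths.max())

# Scatter the flat row back to batch format, right-padding with 0.0,
# mirroring what pad_logps_back_to_batch is used for above.
padded = torch.zeros(batch_size, max_seq_len)
for i in range(batch_size):
    start, length = cu_seq_lens_q[i], seq_lengths[i]
    padded[i, :length] = values_rmpad[0, start:start + length]
```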
@@ -400,10 +419,10 @@ def _get_encoded_batch(rollout_batch):
 
         if self.loss_type in ['cispo', 'dapo']:
             # Calculate num_items_in_batch
-            # Count tokens from all mini_batch_data (this includes gathered data from rollout_group)
-            total_token_count = sum(batch_data['seq_lengths'].sum().item() if self.template.
-                                    padding_free else batch_data['completion_mask'].sum().item()
-                                    for batch_data in mini_batch_data)
+            # Count completion tokens from all mini_batch_data (this includes gathered data from rollout_group).
+            # Use completion_mask.sum() for both padding_free and non-padding_free modes,
+            # since we want the count of actual completion tokens, not sequence lengths.
+            total_token_count = sum(batch_data['completion_mask'].sum().item() for batch_data in mini_batch_data)
 
             # All-reduce across all ranks
             total_token_count_tensor = torch.tensor(total_token_count, dtype=torch.int, device=self.device)
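The counting change above matters because `seq_lengths.sum()` counts prompt plus completion tokens, while `completion_mask.sum()` counts only supervised (label != -100) positions. A hedged sketch of the count-then-all-reduce pattern (hypothetical helper name; assumes the trainer has already initialized torch.distributed):

```python
import torch
import torch.distributed as dist

def global_completion_token_count(mini_batch_data, device):
    # Hypothetical helper mirroring the logic above: completion_mask marks
    # label != -100 positions in both layouts, so its sum is the completion
    # token count regardless of padding_free.
    total = sum(batch['completion_mask'].sum().item() for batch in mini_batch_data)
    total_tensor = torch.tensor(total, dtype=torch.int, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)  # sum across ranks
    return total_tensor.item()
```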
@@ -873,22 +892,34 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         with torch.no_grad(), self.null_ref_context() as ref_models:
             assert len(ref_models) == 1, 'GRPO currently does not support VPP.'
             ref_model = ref_models[0]
-            ref_per_token_logps_rmpad = self.model_forward(
+            ref_per_token_logps_raw = self.model_forward(
                 ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
-            ref_per_token_logps, _ = pad_logps_back_to_batch(
-                logps_rmpad=ref_per_token_logps_rmpad,
-                logits_to_keep=max_seq_len,
-                batch_size=batch_size,
-                seq_lengths=seq_lengths)
+            if self.template.padding_free:
+                # In padding_free mode, logps are in rmpad format [1, total_tokens];
+                # pad them to batch format [batch_size, max_seq_len].
+                ref_per_token_logps, _ = pad_logps_back_to_batch(
+                    logps_rmpad=ref_per_token_logps_raw,
+                    logits_to_keep=max_seq_len,
+                    batch_size=batch_size,
+                    seq_lengths=seq_lengths)
+            else:
+                # In non-padding_free mode, logps are already in batch format [batch_size, seq_len].
+                ref_per_token_logps = ref_per_token_logps_raw
             batch['ref_per_token_logps'] = ref_per_token_logps
 
-        old_per_token_logps_rmpad = self.model_forward(
+        old_per_token_logps_raw = self.model_forward(
             self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
-        old_per_token_logps, _ = pad_logps_back_to_batch(
-            logps_rmpad=old_per_token_logps_rmpad,
-            logits_to_keep=max_seq_len,
-            batch_size=batch_size,
-            seq_lengths=seq_lengths)
+        if self.template.padding_free:
+            # In padding_free mode, logps are in rmpad format [1, total_tokens];
+            # pad them to batch format [batch_size, max_seq_len].
+            old_per_token_logps, _ = pad_logps_back_to_batch(
+                logps_rmpad=old_per_token_logps_raw,
+                logits_to_keep=max_seq_len,
+                batch_size=batch_size,
+                seq_lengths=seq_lengths)
+        else:
+            # In non-padding_free mode, logps are already in batch format [batch_size, seq_len].
+            old_per_token_logps = old_per_token_logps_raw
         batch['old_per_token_logps'] = old_per_token_logps
 
         return batch
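Since the same pad-or-passthrough branch now appears twice (for ref and old logps), one possible follow-up is to factor it into a small method. This is a hypothetical refactor sketch, not part of the diff; it reuses `pad_logps_back_to_batch` from the file above:

```python
def _logps_to_batch_format(self, logps_raw, max_seq_len, batch_size, seq_lengths):
    """Hypothetical refactor of the branch repeated above."""
    if self.template.padding_free:
        # rmpad format [1, total_tokens] -> batch format [batch_size, max_seq_len]
        logps, _ = pad_logps_back_to_batch(
            logps_rmpad=logps_raw,
            logits_to_keep=max_seq_len,
            batch_size=batch_size,
            seq_lengths=seq_lengths)
        return logps
    # Already [batch_size, seq_len] in non-padding_free mode.
    return logps_raw
```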
@@ -985,7 +1016,16 @@ def build_pretraining_data_loader(*_args, **kwargs):
     def forward_step(self, data_iterator, model):
         # train_batch_size
         # return: output_tensor, loss_func
-        data = self.get_batch(data_iterator)
+        data = next(data_iterator)
+        advantages = data.pop('advantages')
+        truncated_mask = data.pop('truncated_mask')
+        seq_lengths = data.pop('seq_lengths')
+        data = self._prepare_batch(data)
+        data.update({
+            'advantages': advantages,
+            'truncated_mask': truncated_mask,
+            'seq_lengths': seq_lengths,
+        })
         data.pop('loss_scale', None)
         inputs = self._prepare_model_inputs(data)
 
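The pop-then-restore dance in `forward_step` exists because `_prepare_batch` only knows standard model inputs and may drop GRPO-specific tensors. A generic sketch of the pattern (hypothetical helper; key names taken from the diff):

```python
GRPO_EXTRA_KEYS = ('advantages', 'truncated_mask', 'seq_lengths')

def prepare_batch_with_extras(data, prepare_fn, extra_keys=GRPO_EXTRA_KEYS):
    # Pull trainer-specific tensors out before generic batch preparation,
    # then reattach them so the loss function can still see them.
    extras = {key: data.pop(key) for key in extra_keys}
    data = prepare_fn(data)
    data.update(extras)
    return data
```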
@@ -995,29 +1035,36 @@ def forward_step(self, data_iterator, model):
 
     @profiling_decorator
     def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
+        args = get_args()
         # Get pre-padded data in batch format [batch_size, max_seq_len]
         advantages = data['advantages']  # [batch_size]
         labels = data['labels']
         completion_mask = data['completion_mask']  # [batch_size, max_seq_len]
-        packed_seq_params = data['packed_seq_params']
+        packed_seq_params = data.get('packed_seq_params')
         truncated_mask = data['truncated_mask']  # [batch_size]
         seq_lengths = data['seq_lengths']  # [batch_size]
         max_seq_len = completion_mask.shape[1]
         micro_batch_size = self.micro_batch_size
 
-        # Use full sequence lengths directly (get_logps returns full sequences in CP mode)
-        lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size
-                                                 + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size]
-
-        # get_logps with per_token=True returns rmpad format [1, total_tokens]
-        # Pad to batch format [batch_size, max_seq_len]
-        per_token_logps_rmpad = self.get_logps(
-            output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True)
-        per_token_logps, _ = pad_logps_back_to_batch(
-            logps_rmpad=per_token_logps_rmpad,
-            logits_to_keep=max_seq_len,
-            batch_size=micro_batch_size,
-            seq_lengths=seq_lengths)
+        if args.padding_free:
+            # Use full sequence lengths directly (get_logps returns full sequences in CP mode).
+            lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size
+                                                     + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size]
+
+            # get_logps with per_token=True returns rmpad format [1, total_tokens];
+            # pad to batch format [batch_size, max_seq_len].
+            per_token_logps_rmpad = self.get_logps(
+                output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True)
+            per_token_logps, _ = pad_logps_back_to_batch(
+                logps_rmpad=per_token_logps_rmpad,
+                logits_to_keep=max_seq_len,
+                batch_size=micro_batch_size,
+                seq_lengths=seq_lengths)
+        else:
+            # In non-padding_free mode, get_logps with per_token=True returns [batch_size, seq_len],
+            # already in batch format; no padding needed.
+            lengths = seq_lengths
+            per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params, micro_batch_size, per_token=True)
 
         # Get pre-padded ref/old/rollout logps from data
         ref_per_token_logps = data.get('ref_per_token_logps')  # [batch_size, max_seq_len] or None
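Downstream of this, per-token quantities are typically reduced over completion positions only. A generic masked-mean sketch (not the trainer's exact reduction) showing how a `[batch_size, max_seq_len]` tensor and `completion_mask` are meant to be combined:

```python
import torch

def masked_token_mean(per_token_values: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor:
    # Average a per-token quantity over completion positions only;
    # the clamp guards against an all-masked (empty) batch.
    mask = completion_mask.float()
    return (per_token_values * mask).sum() / mask.sum().clamp(min=1.0)
```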
@@ -1256,13 +1303,19 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
         with self.stimer(bdata=True):
             data = self.get_batch(data_iterator)
         data.pop('loss_scale', None)
+        input_ids = data.get('input_ids')
         labels = data.get('labels')
         context = torch.no_grad() if no_grad else nullcontext()
         with context:
             output_tensor = forward_step_helper(model, data)
-        packed_seq_params = data['packed_seq_params']
+        # packed_seq_params only exists in padding_free mode
+        packed_seq_params = data.get('packed_seq_params')
+        if packed_seq_params is not None:
+            num_samples = packed_seq_params.num_samples
+        else:
+            num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
         data['logps'] = None if labels is None else self.get_logps(
-            output_tensor, labels, data['packed_seq_params'], packed_seq_params.num_samples, per_token=per_token)
+            output_tensor, labels, packed_seq_params, num_samples, per_token=per_token)
         return data
 
     def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]:
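Shape expectations implied by the `model_forward` change, handy when debugging either mode (a sketch; `out` stands for the dict returned by `model_forward` with per_token=True):

```python
def check_logps_layout(out):
    # padding_free:     rmpad layout, one packed row [1, total_tokens]
    # non-padding_free: batch layout [batch_size, seq_len]
    logps = out['logps']
    if out.get('packed_seq_params') is not None:
        assert logps.dim() == 2 and logps.shape[0] == 1
    else:
        assert logps.shape[0] == out['input_ids'].shape[0]
```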