@@ -325,3 +325,120 @@ def __call__(self, batch: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
         add_flash_attention_kwargs_from_position_ids(batch)
 
         return batch
+
+
+@dataclass
+class ClassificationDataCollatorWithPositionIDs(DataCollator):
332+ """
333+ Reuse DataCollatorWithPositionIDs from veomni.data,
334+ but remove the part that masks out the labels corresponding to the boundary tokens of each subsequence.
335+ """
+
+    def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
+        batch = {}
+        for input_name in features[0].keys():
+            if input_name in ("input_ids", "attention_mask", "labels", "position_ids"):
+                batch[input_name] = torch.cat([feature[input_name] for feature in features], dim=-1).unsqueeze(0)
+            else:
+                batch[input_name] = default_collate([feature[input_name] for feature in features])
+
+        if "position_ids" not in batch:
+            batch["position_ids"] = torch.cat(
+                [torch.arange(len(feature["input_ids"])) for feature in features]
+            ).unsqueeze(0)
+
+        # cu_seq_lens_q should equal cu_seq_lens_k, and max_length_q should equal max_length_k.
+        if not get_parallel_state().sp_enabled:
+            # Attach cu_seqlens and max_length here only when sequence parallelism is disabled.
+            # When sp_enabled is True, position_ids are padded later, so the flash-attention
+            # kwargs are computed after that padding instead.
+            add_flash_attention_kwargs_from_position_ids(batch)
+
+        return batch
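# A minimal sketch of how this collator packs two toy samples (made-up values; the
# flash-attention kwargs step is omitted since it depends on the parallel state).
import torch

features = [
    {"input_ids": torch.tensor([5, 6, 7]), "labels": torch.tensor([-100, -100, 2])},
    {"input_ids": torch.tensor([8, 9]), "labels": torch.tensor([-100, 1])},
]
packed_ids = torch.cat([f["input_ids"] for f in features], dim=-1).unsqueeze(0)        # [[5, 6, 7, 8, 9]]
packed_labels = torch.cat([f["labels"] for f in features], dim=-1).unsqueeze(0)        # [[-100, -100, 2, -100, 1]]
position_ids = torch.cat([torch.arange(len(f["input_ids"])) for f in features]).unsqueeze(0)
# position_ids -> [[0, 1, 2, 0, 1]]: they restart at every packed subsequence, and the class-id
# label on the last token of each subsequence (2 and 1 here) is kept, since boundary masking
# was removed.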
+
+
+@dataclass
+class ClassificationTextSequenceShardCollator(DataCollator):
+    """
+    Patch of TextSequenceShardCollator for sequence-classification (token-level) labels:
+    - no label shift
+    - no masking of the last token of each sequence
+    Everything else is kept identical. (A worked example of the SP padding/slicing
+    arithmetic follows this class.)
+    """
+
+    rmpad: bool
+    rmpad_with_pos_ids: bool
+    pad_token_id: int = 0
+
+    def __post_init__(self):
+        self.sp_size = get_parallel_state().sp_size
+        self.sp_rank = get_parallel_state().sp_rank
+
+    def sp_slice(self, tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        # Ceil-divide the sequence across sp_size ranks and return this rank's contiguous chunk.
+        seq_length = tensor.size(dim)
+        sp_chunk_size = (seq_length + self.sp_size - 1) // self.sp_size
+        return tensor.narrow(dim, self.sp_rank * sp_chunk_size, sp_chunk_size)
+
+    def sp_padding(
+        self, tensor: torch.Tensor, dim: int = -1, pad_value: int = 0, pad_length: int = 0, sequential: bool = False
+    ) -> torch.Tensor:
+        # Right-pad `tensor` along `dim` by `pad_length`; with sequential=True the pad is
+        # 0..pad_length-1 (used for position_ids), otherwise a constant `pad_value`.
+        if pad_length == 0:
+            return tensor
+        pad_shape = list(tensor.shape)
+        pad_shape[dim] = pad_length
+        if sequential:
+            seq = torch.arange(pad_length, device=tensor.device, dtype=tensor.dtype)
+            view_shape = [1] * tensor.ndim
+            view_shape[dim] = pad_length
+            pad = seq.view(view_shape).expand(pad_shape)
+        else:
+            pad = torch.full(pad_shape, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device)
+        return torch.cat((tensor, pad), dim=dim)
+
+    def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        input_ids = batch.pop("input_ids")
+
+        # CHANGED vs. TextSequenceShardCollator: do NOT shift labels, since these are
+        # token-level sequence-classification labels.
+        labels = batch.pop("labels").contiguous()
+
+        # CHANGED: do NOT mask the last token of each sequence (it carries the class id).
+        if (not self.rmpad_with_pos_ids) and (not self.rmpad) and ("position_ids" not in batch):
+            batch["position_ids"] = torch.arange(0, input_ids.size(-1), device=input_ids.device).unsqueeze(0)
+
+        # sp padding
+        seq_length = input_ids.size(-1)
+        sp_chunk_size = (seq_length + self.sp_size - 1) // self.sp_size
+        pad_length = sp_chunk_size * self.sp_size - seq_length
+
+        input_ids = self.sp_padding(input_ids, dim=-1, pad_value=self.pad_token_id, pad_length=pad_length)
+        labels = self.sp_padding(labels, dim=-1, pad_value=IGNORE_INDEX, pad_length=pad_length)
+
+        if self.rmpad_with_pos_ids:
+            batch["attention_mask"] = self.sp_padding(
+                batch["attention_mask"], dim=-1, pad_value=1, pad_length=pad_length
+            )
+        else:
+            batch["attention_mask"] = self.sp_padding(
+                batch["attention_mask"], dim=-1, pad_value=0, pad_length=pad_length
+            )
+
+        if self.rmpad:
+            if pad_length > 0:
+                batch["cu_seqlens"] = F.pad(
+                    batch["cu_seqlens"], (0, 1), "constant", batch["cu_seqlens"][-1].item() + pad_length
+                )
+        else:
+            batch["position_ids"] = self.sp_padding(
+                batch["position_ids"], dim=-1, pad_value=0, pad_length=pad_length, sequential=True
+            )
+
+        # sp slice
+        batch["input_ids"] = self.sp_slice(input_ids, dim=-1)
+        batch["labels"] = self.sp_slice(labels, dim=-1)
+
+        if not self.rmpad:
+            add_flash_attention_kwargs_from_position_ids(batch)
+
+        return batch
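
A small standalone sketch of the SP padding and slicing arithmetic this collator relies on, assuming sp_size = 4, a packed length of 10, and IGNORE_INDEX = -100 (the usual value); it is independent of veomni's parallel state and only mirrors the math above.

    import torch

    sp_size, seq_length = 4, 10
    sp_chunk_size = (seq_length + sp_size - 1) // sp_size  # ceil(10 / 4) = 3
    pad_length = sp_chunk_size * sp_size - seq_length      # 3 * 4 - 10   = 2

    # Labels are right-padded with IGNORE_INDEX so the pad never contributes to the loss;
    # input_ids would be padded with pad_token_id in the same way.
    labels = torch.arange(seq_length)
    labels = torch.cat((labels, torch.full((pad_length,), -100)))

    # Each rank r keeps one contiguous chunk of the padded sequence: [3 * r, 3 * r + 3).
    # There is no label shift and no masking of each sequence's final (class-id) token.
    chunks = [labels.narrow(0, r * sp_chunk_size, sp_chunk_size) for r in range(sp_size)]
    assert sum(chunk.numel() for chunk in chunks) == sp_chunk_size * sp_size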