
Commit e45423c

update data collator to support embedding classification
1 parent 78f3965 commit e45423c

File tree

1 file changed: +117 -0


veomni/data/data_collator.py

Lines changed: 117 additions & 0 deletions
@@ -325,3 +325,120 @@ def __call__(self, batch: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
        add_flash_attention_kwargs_from_position_ids(batch)

        return batch


@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollator):
    """
    Reuse DataCollatorWithPositionIDs from veomni.data, but remove the part that masks out
    the labels corresponding to the boundary tokens of each subsequence.
    """

    def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
        batch = {}
        for input_name in features[0].keys():
            if input_name in ("input_ids", "attention_mask", "labels", "position_ids"):
                batch[input_name] = torch.cat([feature[input_name] for feature in features], dim=-1).unsqueeze(0)
            else:
                batch[input_name] = default_collate([feature[input_name] for feature in features])

        if "position_ids" not in batch:
            batch["position_ids"] = torch.cat(
                [torch.arange(len(feature["input_ids"])) for feature in features]
            ).unsqueeze(0)

        # cu_seq_lens_q should equal cu_seq_lens_k, and max_length_q should equal max_length_k.
        if not get_parallel_state().sp_enabled:
            # We only enter this branch to pass down cu_seqlens and max_length when sequence
            # parallelism is not enabled. When sp_enabled is True, position_ids will be padded
            # later, so we calculate them after padding.
            cu_seq_lens_q, _, _, _ = add_flash_attention_kwargs_from_position_ids(batch)
        else:
            # cu_seq_lens_q is still needed for label masking even when sp_enabled is True.
            (cu_seq_lens_q, _), (_, _) = prepare_fa_kwargs_from_position_ids(batch["position_ids"])

        return batch
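
# Hedged usage sketch (not part of the commit): how this collator is expected to be fed.
# Each feature is a dict of 1-D tensors for one sample; the collator packs them into a single
# batch row of shape (1, total_len). The values and the elided constructor arguments "(...)"
# below are illustrative assumptions, not taken from the repository:
#
#   features = [
#       {"input_ids": torch.tensor([5, 6, 7]), "labels": torch.tensor([-100, -100, 2])},
#       {"input_ids": torch.tensor([8, 9]),    "labels": torch.tensor([-100, 1])},
#   ]
#   batch = ClassificationDataCollatorWithPositionIDs(...)(features)
#   # batch["input_ids"]    -> tensor([[5, 6, 7, 8, 9]])            shape (1, 5)
#   # batch["position_ids"] -> tensor([[0, 1, 2, 0, 1]])            restarts at each sample boundary
#   # batch["labels"]       -> tensor([[-100, -100, 2, -100, 1]])   per-token class ids left unmasked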


@dataclass
class ClassificationTextSequenceShardCollator(DataCollator):
    """
    Patch of TextSequenceShardCollator for SeqCls token-level labels:
    - NO label shift
    - NO masking of the last token
    Everything else is kept identical.
    """

    rmpad: bool
    rmpad_with_pos_ids: bool
    pad_token_id: int = 0

    def __post_init__(self):
        self.sp_size = get_parallel_state().sp_size
        self.sp_rank = get_parallel_state().sp_rank

    def sp_slice(self, tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
        seq_length = tensor.size(dim)
        sp_chunk_size = (seq_length + self.sp_size - 1) // self.sp_size
        return tensor.narrow(dim, self.sp_rank * sp_chunk_size, sp_chunk_size)

    def sp_padding(
        self, tensor: torch.Tensor, dim: int = -1, pad_value: int = 0, pad_length: int = 0, sequential: bool = False
    ) -> torch.Tensor:
        if pad_length == 0:
            return tensor
        pad_shape = list(tensor.shape)
        pad_shape[dim] = pad_length
        if sequential:
            seq = torch.arange(pad_length, device=tensor.device, dtype=tensor.dtype)
            view_shape = [1] * tensor.ndim
            view_shape[dim] = pad_length
            pad = seq.view(view_shape).expand(pad_shape)
        else:
            pad = torch.full(pad_shape, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device)
        return torch.cat((tensor, pad), dim=dim)

    def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        input_ids = batch.pop("input_ids")

        # CHANGED: do NOT shift labels for seq-cls token-level labels.
        labels = batch.pop("labels").contiguous()

        # CHANGED: do NOT mask the last token of each sequence (the per-token class id sits there).
        if (not self.rmpad_with_pos_ids) and (not self.rmpad) and ("position_ids" not in batch):
            batch["position_ids"] = torch.arange(0, input_ids.size(-1), device=input_ids.device).unsqueeze(0)

        # sp padding
        seq_length = input_ids.size(-1)
        sp_chunk_size = (seq_length + self.sp_size - 1) // self.sp_size
        pad_length = sp_chunk_size * self.sp_size - seq_length

        input_ids = self.sp_padding(input_ids, dim=-1, pad_value=self.pad_token_id, pad_length=pad_length)
        labels = self.sp_padding(labels, dim=-1, pad_value=IGNORE_INDEX, pad_length=pad_length)

        if self.rmpad_with_pos_ids:
            batch["attention_mask"] = self.sp_padding(
                batch["attention_mask"], dim=-1, pad_value=1, pad_length=pad_length
            )
        else:
            batch["attention_mask"] = self.sp_padding(
                batch["attention_mask"], dim=-1, pad_value=0, pad_length=pad_length
            )

        if self.rmpad:
            if pad_length > 0:
                batch["cu_seqlens"] = F.pad(
                    batch["cu_seqlens"], (0, 1), "constant", batch["cu_seqlens"][-1].item() + pad_length
                )
        else:
            batch["position_ids"] = self.sp_padding(
                batch["position_ids"], dim=-1, pad_value=0, pad_length=pad_length, sequential=True
            )

        # sp slice
        batch["input_ids"] = self.sp_slice(input_ids, dim=-1)
        batch["labels"] = self.sp_slice(labels, dim=-1)

        if not self.rmpad:
            add_flash_attention_kwargs_from_position_ids(batch)

        return batch
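
For reference, a minimal sketch (not part of the commit) of the sequence-parallel padding and slicing arithmetic that ClassificationTextSequenceShardCollator applies before sharding; the sequence length and sp_size below are illustrative assumptions:

    # Illustrative numbers only: packed length 10, sequence-parallel world size 4.
    seq_length = 10
    sp_size = 4
    sp_chunk_size = (seq_length + sp_size - 1) // sp_size  # ceil(10 / 4) = 3
    pad_length = sp_chunk_size * sp_size - seq_length       # 12 - 10 = 2
    # input_ids are right-padded with pad_token_id and labels with IGNORE_INDEX to length 12,
    # then rank r keeps tensor.narrow(-1, r * sp_chunk_size, sp_chunk_size):
    # rank 0 -> tokens 0:3, rank 1 -> 3:6, rank 2 -> 6:9, rank 3 -> 9:12 (includes the 2 pad slots).
    # Unlike the causal-LM shard collator, labels are neither shifted by one position nor masked
    # at the last token, so a per-token class id placed on the final token survives the sharding.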
