Commit 6b30748

support batch flattening collator (#2499)
1 parent 5431a57

1 file changed: 22 additions, 0 deletions

swift/llm/utils/template.py

old mode 100644
new mode 100755
@@ -10,6 +10,7 @@
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union
 
 import json
+import numpy as np
 import torch
 import torch.nn.functional as F
 import transformers
@@ -998,6 +999,27 @@ def pad_sequence(sequences: List[torch.Tensor],
 
         return torch.stack(padded_sequences)
 
+    def data_collator_with_flattening(self,
+                                      batch: List[Dict[str, Any]],
+                                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+        """
+        Data collator used for the padding-free approach. Does the following:
+
+        - concatenates the entire mini-batch into a single long sequence of shape [1, total_tokens]
+        - adds no padding; returns `input_ids`, `labels` and `position_ids`
+
+        Args:
+            batch(`List[Dict[str, Any]]`): The input data in batch
+            padding_to(`int`, optional): If set, pad the batch to this fixed length; if `None`,
+                the batch is padded to the longest sequence
+        """
+        packed_data = {}
+        position_id_lengths = [len(item['input_ids']) for item in batch]
+        packed_data['input_ids'] = np.concatenate([item['input_ids'] for item in batch])
+        packed_data['labels'] = np.concatenate([item['labels'] for item in batch])
+        packed_data['position_ids'] = np.concatenate([list(range(pil)) for pil in position_id_lengths])
+        return self.data_collator([packed_data], padding_to)
+
     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
         """
         Args:
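
For reference, here is a minimal standalone sketch of what the flattening step produces for a small batch. flatten_batch is a hypothetical helper that mirrors the packing logic of the new method (the real method additionally forwards the packed example to self.data_collator for padding and tensorization); the token IDs are made up.

import numpy as np

# Hypothetical standalone version of the packing logic in
# data_collator_with_flattening (illustration only).
def flatten_batch(batch):
    lengths = [len(item['input_ids']) for item in batch]
    return {
        # One long sequence; becomes [1, total_tokens] once tensorized downstream.
        'input_ids': np.concatenate([item['input_ids'] for item in batch]),
        'labels': np.concatenate([item['labels'] for item in batch]),
        # position_ids restart at 0 for each original sequence.
        'position_ids': np.concatenate([list(range(n)) for n in lengths]),
    }

batch = [
    {'input_ids': [101, 7, 8, 102], 'labels': [-100, 7, 8, 102]},
    {'input_ids': [101, 9, 102], 'labels': [-100, 9, 102]},
]
packed = flatten_batch(batch)
print(packed['input_ids'])     # [101   7   8 102 101   9 102]
print(packed['position_ids'])  # [0 1 2 3 0 1 2]

The restart of position_ids at every sequence boundary is what lets a padding-free attention implementation (e.g., FlashAttention's variable-length path) recover where each original sample begins and ends without an attention mask.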
