10 | 10 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union |
11 | 11 |
12 | 12 | import json |
| 13 | +import numpy as np |
13 | 14 | import torch |
14 | 15 | import torch.nn.functional as F |
15 | 16 | import transformers |
@@ -998,6 +999,27 @@ def pad_sequence(sequences: List[torch.Tensor], |
998 | 999 |
999 | 1000 |         return torch.stack(padded_sequences)
1000 | 1001 |
| 1002 | +    def data_collator_with_flattening(self,
| 1003 | +                                      batch: List[Dict[str, Any]],
| 1004 | +                                      padding_to: Optional[int] = None) -> Dict[str, Any]:
| 1005 | +        """
| 1006 | +        Data collator used for the padding-free approach. Does the following:
| 1007 | +
| 1008 | +        - concatenates the entire mini-batch into a single long sequence of shape [1, total_tokens]
| 1009 | +        - adds no padding; returns `input_ids`, `labels` and `position_ids`
| 1010 | +
| 1011 | +        Args:
| 1012 | +            batch(`List[Dict[str, Any]]`): The input batch of samples
| 1013 | +            padding_to(`int`, optional): The fixed length to pad the batch to; if None, the batch
| 1014 | +                will be padded to the longest sequence
| 1015 | +        """
| 1016 | +        packed_data = {}
| 1017 | +        position_id_lengths = [len(item['input_ids']) for item in batch]
| 1018 | +        packed_data['input_ids'] = np.concatenate([item['input_ids'] for item in batch])
| 1019 | +        packed_data['labels'] = np.concatenate([item['labels'] for item in batch])
| 1020 | +        packed_data['position_ids'] = np.concatenate([list(range(pil)) for pil in position_id_lengths])
| 1021 | +        return self.data_collator([packed_data], padding_to)
| 1022 | +
1001 | 1023 |     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
1002 | 1024 |         """
1003 | 1025 |         Args:
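
For context, here is a minimal standalone sketch of what the new collator packs together, assuming each batch item carries plain Python lists under `input_ids` and `labels` (the surrounding class and its `data_collator` method are not shown in this diff; the sample values below are purely illustrative):

```python
import numpy as np

# Hypothetical mini-batch of two samples; field names follow the diff above.
batch = [
    {'input_ids': [1, 2, 3], 'labels': [-100, 2, 3]},
    {'input_ids': [4, 5], 'labels': [4, 5]},
]

# Flatten the batch into one long sequence, mirroring data_collator_with_flattening:
# no padding tokens are inserted, and position_ids restart at 0 at every sample
# boundary so the sequence boundaries remain recoverable.
position_id_lengths = [len(item['input_ids']) for item in batch]
packed = {
    'input_ids': np.concatenate([item['input_ids'] for item in batch]),
    'labels': np.concatenate([item['labels'] for item in batch]),
    'position_ids': np.concatenate([np.arange(n) for n in position_id_lengths]),
}

print(packed['input_ids'])     # [1 2 3 4 5]
print(packed['position_ids'])  # [0 1 2 0 1]
```

In the diff itself, this packed dictionary is then handed to the existing `data_collator` as a single-element batch, which is what produces the final `[1, total_tokens]` tensors.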