1 | | -from typing import Any, Mapping, Sequence |
| 1 | +from typing import Dict, Sequence |
2 | 2 |
3 | 3 | import torch |
4 | 4 | from mmengine.dataset.utils import COLLATE_FUNCTIONS |
5 | | -from mmengine.structures import BaseDataElement |
6 | 5 |
7 | 6 |
8 | 7 | @COLLATE_FUNCTIONS.register_module() |
9 | | -def long_text_data_collate(data_batch: Sequence, training: bool = True) -> Any: |
10 | | - """This code is referenced from |
11 | | - ``mmengine.dataset.utils.default_collate``""" |
12 | | - data_item = data_batch[0] |
13 | | - data_item_type = type(data_item) |
| 8 | +def ser_collate(data_batch: Sequence, training: bool = True) -> Dict: |
| 9 | + """A collate function designed for SER. |
14 | 10 |
15 | | - if isinstance(data_item, (BaseDataElement, str, bytes)): |
16 | | - return data_batch |
17 | | - elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'): |
18 | | - # named_tuple |
19 | | - return data_item_type(*(long_text_data_collate(samples, training) |
20 | | - for samples in zip(*data_batch))) |
21 | | - elif isinstance(data_item, list): |
22 | | - flattened_data_batch = [ |
23 | | - sub_item for item in data_batch for sub_item in item |
24 | | - ] |
25 | | - if training: |
26 | | - return flattened_data_batch[:len(data_batch)] |
27 | | - else: |
28 | | - return flattened_data_batch |
29 | | - elif isinstance(data_item, Sequence): |
30 | | - # check to make sure that the data_itements in batch have |
31 | | - # consistent size |
32 | | - it = iter(data_batch) |
33 | | - data_item_size = len(next(it)) |
34 | | - if not all(len(data_item) == data_item_size for data_item in it): |
35 | | - raise RuntimeError( |
36 | | - 'each data_itement in list of batch should be of equal size') |
37 | | - transposed = list(zip(*data_batch)) |
| 11 | + Args: |
| 12 | + data_batch (Sequence): Data sampled from dataset. |
| 13 | + Like: |
| 14 | + [ |
| 15 | + { |
| 16 | + 'inputs': {'input_ids': ..., 'bbox': ..., ...}, |
| 17 | + 'data_samples': ['SERDataSample_1'] |
| 18 | + }, |
| 19 | + { |
| 20 | + 'inputs': {'input_ids': ..., 'bbox': ..., ...}, |
| 21 | + 'data_samples': ['SERDataSample_1', 'SERDataSample_2', ...] |
| 22 | + }, |
| 23 | + ... |
| 24 | + ] |
| 25 | + training (bool): Whether in training mode. Defaults to True.
38 | 26 |
39 | | - if isinstance(data_item, tuple): |
40 | | - return [ |
41 | | - long_text_data_collate(samples, training) |
42 | | - for samples in transposed |
43 | | - ] # Compat with Pytorch. |
44 | | - else: |
45 | | - try: |
46 | | - return data_item_type([ |
47 | | - long_text_data_collate(samples, training) |
48 | | - for samples in transposed |
49 | | - ]) |
50 | | - except TypeError: |
51 | | - # The sequence type may not support `__init__(iterable)` |
52 | | - # (e.g., `range`). |
53 | | - return [ |
54 | | - long_text_data_collate(samples, training) |
55 | | - for samples in transposed |
56 | | - ] |
57 | | - elif isinstance(data_item, Mapping): |
58 | | - return data_item_type({ |
59 | | - key: long_text_data_collate([d[key] for d in data_batch], training) |
60 | | - for key in data_item |
61 | | - }) |
62 | | - else: |
63 | | - concat_data_batch = torch.concat(data_batch, dim=0) |
64 | | - if training: |
65 | | - return concat_data_batch[:len(data_batch)] |
66 | | - else: |
67 | | - return concat_data_batch |
| 27 | + Note: |
| 28 | + Unlike ``default_collate`` in PyTorch or MMEngine,
| 29 | + ``ser_collate`` accepts `inputs` tensors and `data_samples`
| 30 | + lists whose sizes differ across samples.
| 31 | +
| 32 | + Returns: |
| 33 | + transposed (Dict): A dict with two elements:
| 34 | + `inputs`, a dict of batched input tensors, and
| 35 | + `data_samples`, a flattened list of data samples.
| 36 | + """ |
| 37 | + batch_size = len(data_batch) |
| 38 | + # transpose `inputs`, which is a dict. |
| 39 | + batch_inputs = [data_item['inputs'] for data_item in data_batch] |
| 40 | + batch_inputs_item = batch_inputs[0] |
| 41 | + transposed_batch_inputs = {} |
| 42 | + for key in batch_inputs_item: |
| 43 | + concat_value = torch.concat([d[key] for d in batch_inputs], dim=0) |
| 44 | + # TODO: because long text will be truncated, the concat_value |
| 45 | + # cannot be sliced directly when training=False. |
| 46 | + # How to support batch inference? |
| 47 | + transposed_batch_inputs[key] = concat_value[:batch_size] \ |
| 48 | + if training else concat_value |
| 49 | + # transpose `data_samples`, which is a list. |
| 50 | + batch_data_samples = [ |
| 51 | + data_item['data_samples'] for data_item in data_batch |
| 52 | + ] |
| 53 | + flattened = [sub_item for item in batch_data_samples for sub_item in item] |
| 54 | + # TODO: because long text will be truncated, the flattened
| 55 | + # list cannot be sliced directly when training=False.
| 56 | + # How to support batch inference? |
| 57 | + transposed_batch_data_samples = flattened[:batch_size] \ |
| 58 | + if training else flattened |
| 59 | + |
| 60 | + transposed = { |
| 61 | + 'inputs': transposed_batch_inputs, |
| 62 | + 'data_samples': transposed_batch_data_samples |
| 63 | + } |
| 64 | + return transposed |