
Commit d9a3a5e (parent f0a03ac)

Simplify: replace the previous default_collate-based long_text_data_collate with a clearer, easier-to-understand ser_collate.

File tree: 2 files changed (+58, -61 lines)
Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 from .transforms import * # NOQA
-from .utils import long_text_data_collate
+from .utils import ser_collate
 from .xfund_dataset import XFUNDDataset
 
-__all__ = ['XFUNDDataset', 'long_text_data_collate']
+__all__ = ['XFUNDDataset', 'ser_collate']
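
Since ser_collate is registered in MMEngine's COLLATE_FUNCTIONS registry (see the second file below), it can be referenced by name from a dataloader config. A minimal sketch, assuming MMEngine's usual convention of resolving collate_fn through the registry and partialing any extra kwargs in; the dataset and sampler fields here are placeholders, not part of this commit:

# Hypothetical dataloader config snippet.
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    dataset=dict(type='XFUNDDataset'),  # placeholder dataset config
    sampler=dict(type='DefaultSampler', shuffle=True),
    # MMEngine looks up 'ser_collate' in COLLATE_FUNCTIONS and binds
    # the remaining keys (here `training=True`) via functools.partial.
    collate_fn=dict(type='ser_collate', training=True))

A validation dataloader would pass training=False, subject to the batch-inference caveat noted in the TODOs below.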
Lines changed: 56 additions & 59 deletions

@@ -1,67 +1,64 @@
-from typing import Any, Mapping, Sequence
+from typing import Dict, Sequence
 
 import torch
 from mmengine.dataset.utils import COLLATE_FUNCTIONS
-from mmengine.structures import BaseDataElement
 
 
 @COLLATE_FUNCTIONS.register_module()
-def long_text_data_collate(data_batch: Sequence, training: bool = True) -> Any:
-    """This code is referenced from
-    ``mmengine.dataset.utils.default_collate``"""
-    data_item = data_batch[0]
-    data_item_type = type(data_item)
+def ser_collate(data_batch: Sequence, training: bool = True) -> Dict:
+    """A collate function designed for SER.
 
-    if isinstance(data_item, (BaseDataElement, str, bytes)):
-        return data_batch
-    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
-        # named_tuple
-        return data_item_type(*(long_text_data_collate(samples, training)
-                                for samples in zip(*data_batch)))
-    elif isinstance(data_item, list):
-        flattened_data_batch = [
-            sub_item for item in data_batch for sub_item in item
-        ]
-        if training:
-            return flattened_data_batch[:len(data_batch)]
-        else:
-            return flattened_data_batch
-    elif isinstance(data_item, Sequence):
-        # check to make sure that the data_itements in batch have
-        # consistent size
-        it = iter(data_batch)
-        data_item_size = len(next(it))
-        if not all(len(data_item) == data_item_size for data_item in it):
-            raise RuntimeError(
-                'each data_itement in list of batch should be of equal size')
-        transposed = list(zip(*data_batch))
+    Args:
+        data_batch (Sequence): Data sampled from the dataset,
+            like:
+            [
+                {
+                    'inputs': {'input_ids': ..., 'bbox': ..., ...},
+                    'data_samples': ['SERDataSample_1']
+                },
+                {
+                    'inputs': {'input_ids': ..., 'bbox': ..., ...},
+                    'data_samples': ['SERDataSample_1', 'SERDataSample_2', ...]
+                },
+                ...
+            ]
+        training (bool): Whether in training mode or not.
 
-        if isinstance(data_item, tuple):
-            return [
-                long_text_data_collate(samples, training)
-                for samples in transposed
-            ]  # Compat with Pytorch.
-        else:
-            try:
-                return data_item_type([
-                    long_text_data_collate(samples, training)
-                    for samples in transposed
-                ])
-            except TypeError:
-                # The sequence type may not support `__init__(iterable)`
-                # (e.g., `range`).
-                return [
-                    long_text_data_collate(samples, training)
-                    for samples in transposed
-                ]
-    elif isinstance(data_item, Mapping):
-        return data_item_type({
-            key: long_text_data_collate([d[key] for d in data_batch], training)
-            for key in data_item
-        })
-    else:
-        concat_data_batch = torch.concat(data_batch, dim=0)
-        if training:
-            return concat_data_batch[:len(data_batch)]
-        else:
-            return concat_data_batch
+    Note:
+        Unlike ``default_collate`` in PyTorch or MMEngine,
+        ``ser_collate`` accepts ``inputs`` tensors and ``data_samples``
+        lists whose shapes differ across samples.
+
+    Returns:
+        transposed (Dict): A dict with two elements:
+            the first element ``inputs`` is a dict,
+            the second element ``data_samples`` is a list.
+    """
+    batch_size = len(data_batch)
+    # Transpose `inputs`, which is a dict of tensors.
+    batch_inputs = [data_item['inputs'] for data_item in data_batch]
+    batch_inputs_item = batch_inputs[0]
+    transposed_batch_inputs = {}
+    for key in batch_inputs_item:
+        concat_value = torch.concat([d[key] for d in batch_inputs], dim=0)
+        # TODO: because long text is truncated into several chunks,
+        # concat_value cannot be sliced directly when training=False.
+        # How to support batch inference?
+        transposed_batch_inputs[key] = concat_value[:batch_size] \
+            if training else concat_value
+    # Transpose `data_samples`, which is a list.
+    batch_data_samples = [
+        data_item['data_samples'] for data_item in data_batch
+    ]
+    flattened = [sub_item for item in batch_data_samples for sub_item in item]
+    # TODO: likewise, `flattened` cannot be sliced directly when
+    # training=False. How to support batch inference?
+    transposed_batch_data_samples = flattened[:batch_size] \
+        if training else flattened
+
+    transposed = {
+        'inputs': transposed_batch_inputs,
+        'data_samples': transposed_batch_data_samples
+    }
+    return transposed
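
To illustrate the behavior, a minimal sketch of calling ser_collate directly. The import path, tensor shapes, and string stand-ins for SERDataSample objects are assumptions for illustration; the scenario is a second sample whose long text was split into two 512-token chunks:

import torch

from datasets.utils import ser_collate  # hypothetical import path

# Two samples: the first fits in one chunk, the second was split into
# two chunks, so its tensors have a leading dim of 2 and it carries
# two data samples.
batch = [
    dict(inputs=dict(input_ids=torch.zeros(1, 512, dtype=torch.long),
                     bbox=torch.zeros(1, 512, 4, dtype=torch.long)),
         data_samples=['SERDataSample_1']),
    dict(inputs=dict(input_ids=torch.zeros(2, 512, dtype=torch.long),
                     bbox=torch.zeros(2, 512, 4, dtype=torch.long)),
         data_samples=['SERDataSample_1', 'SERDataSample_2']),
]

out = ser_collate(batch, training=True)
print(out['inputs']['input_ids'].shape)  # torch.Size([2, 512]), sliced to batch_size
print(len(out['data_samples']))          # 2

out = ser_collate(batch, training=False)
print(out['inputs']['input_ids'].shape)  # torch.Size([3, 512]), all chunks kept
print(len(out['data_samples']))          # 3

Note that with training=True the slice concat_value[:batch_size] keeps only the first batch_size chunks of the concatenated batch, which is what the TODO comments about batch inference refer to.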
