
Commit d3e16ad

[Fix] Revert the huggingface-style dataset configs by deleting them (they add no value); update the metainfo written by the SER/RE packers; add a first-pass SERDataset.
1 parent 078cc83 commit d3e16ad

File tree

12 files changed: +275 -61 lines


configs/re/_base_/datasets/xfund_zh_huggingface.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

configs/ser/_base_/datasets/xfund_zh_huggingface.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

mmocr/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,10 +5,11 @@
 from .recog_lmdb_dataset import RecogLMDBDataset
 from .recog_text_dataset import RecogTextDataset
 from .samplers import *  # NOQA
+from .ser_dataset import SERDataset
 from .transforms import *  # NOQA
 from .wildreceipt_dataset import WildReceiptDataset

 __all__ = [
     'IcdarDataset', 'OCRDataset', 'RecogLMDBDataset', 'RecogTextDataset',
-    'WildReceiptDataset', 'ConcatDataset'
+    'WildReceiptDataset', 'ConcatDataset', 'SERDataset'
 ]
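With this import in place, the dataset is registered and can be built through the MMOCR registry. A minimal sketch follows; the tokenizer name, data_root and ann_file are illustrative assumptions, and building will attempt to read that annotation file.

from mmocr.registry import DATASETS
from mmocr.utils import register_all_modules

register_all_modules()  # registers MMOCR modules, including the new SERDataset
dataset = DATASETS.build(
    dict(
        type='SERDataset',
        tokenizer='bert-base-chinese',  # assumption: any HuggingFace tokenizer name
        data_root='data/xfund/zh',      # assumption
        ann_file='ser_train.json',      # assumption
        pipeline=[]))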

mmocr/datasets/preparers/config_generators/re_config_generator.py

Lines changed: 1 addition & 3 deletions
@@ -90,9 +90,7 @@ def _gen_dataset_config(self) -> str:
             cfg += '    type=\'REDataset\',\n'
             cfg += '    data_root=' + f'{self.dataset_name}_{self.task}_data_root,\n'  # noqa: E501
             cfg += f'    ann_file=\'{ann_dict["ann_file"]}\',\n'
-            if ann_dict['split'] == 'train':
-                cfg += '    filter_cfg=dict(filter_empty_gt=True, min_size=32),\n'  # noqa: E501
-            elif ann_dict['split'] in ['test', 'val']:
+            if ann_dict['split'] in ['test', 'val']:
                 cfg += '    test_mode=True,\n'
             cfg += '    pipeline=None)\n'
         return cfg

mmocr/datasets/preparers/config_generators/ser_config_generator.py

Lines changed: 1 addition & 3 deletions
@@ -90,9 +90,7 @@ def _gen_dataset_config(self) -> str:
             cfg += '    type=\'SERDataset\',\n'
             cfg += '    data_root=' + f'{self.dataset_name}_{self.task}_data_root,\n'  # noqa: E501
             cfg += f'    ann_file=\'{ann_dict["ann_file"]}\',\n'
-            if ann_dict['split'] == 'train':
-                cfg += '    filter_cfg=dict(filter_empty_gt=True, min_size=32),\n'  # noqa: E501
-            elif ann_dict['split'] in ['test', 'val']:
+            if ann_dict['split'] in ['test', 'val']:
                 cfg += '    test_mode=True,\n'
             cfg += '    pipeline=None)\n'
         return cfg
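After this change the generator only emits test_mode=True for the test/val splits and no longer writes a filter_cfg line for the train split. Assuming the unchanged part of _gen_dataset_config opens the dict as '{dataset_name}_{task}_{split} = dict(', the emitted snippet for an xfund_zh SER test split would look roughly like the sketch below; the variable name and ann_file value are assumptions, only the body lines come from the visible generator code.

xfund_zh_ser_test = dict(
    type='SERDataset',
    data_root=xfund_zh_ser_data_root,
    ann_file='ser_test.json',  # assumption: the real file name comes from the preparer
    test_mode=True,
    pipeline=None)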

mmocr/datasets/preparers/packers/re_packer.py

Lines changed: 29 additions & 13 deletions
@@ -22,16 +22,22 @@ class REPacker(BasePacker):
         "task_name": "re",
         "labels": ['answer', 'header', 'other', 'question'],
         "id2label": {
-            "0": "answer",
-            "1": "header",
-            "2": "other",
-            "3": "question"
+            "0": "O",
+            "1": "B-ANSWER",
+            "2": "I-ANSWER",
+            "3": "B-HEADER",
+            "4": "I-HEADER",
+            "5": "B-QUESTION",
+            "6": "I-QUESTION"
         },
         "label2id": {
-            "answer": 0,
-            "header": 1,
-            "other": 2,
-            "question": 3
+            "O": 0,
+            "B-ANSWER": 1,
+            "I-ANSWER": 2,
+            "B-HEADER": 3,
+            "I-HEADER": 4,
+            "B-QUESTION": 5,
+            "I-QUESTION": 6
         }
     },
     "data_list":

@@ -141,21 +147,31 @@ def add_meta(self, sample: List) -> Dict:
             Dict: A dict contains the meta information and samples.
         """

+        def get_BIO_label_list(labels):
+            bio_label_list = []
+            for label in labels:
+                if label == 'other':
+                    bio_label_list.insert(0, 'O')
+                else:
+                    bio_label_list.append(f'B-{label.upper()}')
+                    bio_label_list.append(f'I-{label.upper()}')
+            return bio_label_list
+
         labels = []
         for s in sample:
             labels += s['instances']['labels']
-        label_list = list(set(labels))
-        label_list.sort()
+        org_label_list = list(set(labels))
+        bio_label_list = get_BIO_label_list(org_label_list)

         meta = {
             'metainfo': {
                 'dataset_type': 'REDataset',
                 'task_name': 're',
-                'labels': label_list,
+                'labels': org_label_list,
                 'id2label': {k: v
-                             for k, v in enumerate(label_list)},
+                             for k, v in enumerate(bio_label_list)},
                 'label2id': {v: k
-                             for k, v in enumerate(label_list)}
+                             for k, v in enumerate(bio_label_list)}
             },
             'data_list': sample
         }
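For reference, here is a standalone sketch of what the new helper produces for the four XFUND-style categories. Note that the real code feeds it list(set(labels)), so the exact ordering of the B-/I- tags depends on set iteration order; the hand-sorted input below reproduces the order shown in the updated docstring.

def get_BIO_label_list(labels):
    # same logic as the helper added to add_meta above
    bio_label_list = []
    for label in labels:
        if label == 'other':
            bio_label_list.insert(0, 'O')
        else:
            bio_label_list.append(f'B-{label.upper()}')
            bio_label_list.append(f'I-{label.upper()}')
    return bio_label_list

print(get_BIO_label_list(['answer', 'header', 'other', 'question']))
# ['O', 'B-ANSWER', 'I-ANSWER', 'B-HEADER', 'I-HEADER', 'B-QUESTION', 'I-QUESTION']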

mmocr/datasets/preparers/packers/ser_packer.py

Lines changed: 29 additions & 13 deletions
@@ -22,16 +22,22 @@ class SERPacker(BasePacker):
         "task_name": "ser",
         "labels": ['answer', 'header', 'other', 'question'],
         "id2label": {
-            "0": "answer",
-            "1": "header",
-            "2": "other",
-            "3": "question"
+            "0": "O",
+            "1": "B-ANSWER",
+            "2": "I-ANSWER",
+            "3": "B-HEADER",
+            "4": "I-HEADER",
+            "5": "B-QUESTION",
+            "6": "I-QUESTION"
         },
         "label2id": {
-            "answer": 0,
-            "header": 1,
-            "other": 2,
-            "question": 3
+            "O": 0,
+            "B-ANSWER": 1,
+            "I-ANSWER": 2,
+            "B-HEADER": 3,
+            "I-HEADER": 4,
+            "B-QUESTION": 5,
+            "I-QUESTION": 6
         }
     },
     "data_list":

@@ -129,21 +135,31 @@ def add_meta(self, sample: List) -> Dict:
             Dict: A dict contains the meta information and samples.
         """

+        def get_BIO_label_list(labels):
+            bio_label_list = []
+            for label in labels:
+                if label == 'other':
+                    bio_label_list.insert(0, 'O')
+                else:
+                    bio_label_list.append(f'B-{label.upper()}')
+                    bio_label_list.append(f'I-{label.upper()}')
+            return bio_label_list
+
         labels = []
         for s in sample:
             labels += s['instances']['labels']
-        label_list = list(set(labels))
-        label_list.sort()
+        org_label_list = list(set(labels))
+        bio_label_list = get_BIO_label_list(org_label_list)

         meta = {
             'metainfo': {
                 'dataset_type': 'SERDataset',
                 'task_name': 'ser',
-                'labels': label_list,
+                'labels': org_label_list,
                 'id2label': {k: v
-                             for k, v in enumerate(label_list)},
+                             for k, v in enumerate(bio_label_list)},
                 'label2id': {v: k
-                             for k, v in enumerate(label_list)}
+                             for k, v in enumerate(bio_label_list)}
             },
             'data_list': sample
         }

mmocr/datasets/ser_dataset.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List, Optional, Sequence, Union

from mmengine.dataset import BaseDataset
from transformers import AutoTokenizer

from mmocr.registry import DATASETS


@DATASETS.register_module()
class SERDataset(BaseDataset):

    def __init__(self,
                 ann_file: str = '',
                 tokenizer: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = '',
                 data_prefix: dict = dict(img_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000) -> None:

        if isinstance(tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=True)
        self.tokenizer = tokenizer

        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            filter_cfg=filter_cfg,
            indices=indices,
            serialize_data=serialize_data,
            pipeline=pipeline,
            test_mode=test_mode,
            lazy_init=lazy_init,
            max_refetch=max_refetch)

    def load_data_list(self) -> List[dict]:
        data_list = super().load_data_list()

        # split text to several slices because of over-length
        input_ids, bboxes, labels = [], [], []
        segment_ids, position_ids = [], []
        image_path = []
        for i in range(len(data_list)):
            start = 0
            cur_iter = 0
            while start < len(data_list[i]['input_ids']):
                end = min(start + 510, len(data_list[i]['input_ids']))

                input_ids.append([self.tokenizer.cls_token_id] +
                                 data_list[i]['input_ids'][start:end] +
                                 [self.tokenizer.sep_token_id])
                bboxes.append([[0, 0, 0, 0]] +
                              data_list[i]['bboxes'][start:end] +
                              [[1000, 1000, 1000, 1000]])
                labels.append([-100] + data_list[i]['labels'][start:end] +
                              [-100])

                cur_segment_ids = self.get_segment_ids(bboxes[-1])
                cur_position_ids = self.get_position_ids(cur_segment_ids)
                segment_ids.append(cur_segment_ids)
                position_ids.append(cur_position_ids)
                image_path.append(
                    os.path.join(self.data_root, data_list[i]['img_path']))

                start = end
                cur_iter += 1

        assert len(input_ids) == len(bboxes) == len(labels) == len(
            segment_ids) == len(position_ids)
        assert len(segment_ids) == len(image_path)

        return data_list

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        instances = raw_data_info['instances']
        img_path = raw_data_info['img_path']
        width = raw_data_info['width']
        height = raw_data_info['height']

        texts = instances.get('texts', None)
        bboxes = instances.get('bboxes', None)
        labels = instances.get('labels', None)
        assert texts or bboxes or labels
        # norm box
        bboxes_norm = [self.box_norm(box, width, height) for box in bboxes]
        # get label2id
        label2id = self.metainfo['label2id']

        cur_doc_input_ids, cur_doc_bboxes, cur_doc_labels = [], [], []
        for j in range(len(texts)):
            cur_input_ids = self.tokenizer(
                texts[j],
                truncation=False,
                add_special_tokens=False,
                return_attention_mask=False)['input_ids']
            if len(cur_input_ids) == 0:
                continue

            cur_label = labels[j].upper()
            if cur_label == 'OTHER':
                cur_labels = ['O'] * len(cur_input_ids)
                for k in range(len(cur_labels)):
                    cur_labels[k] = label2id[cur_labels[k]]
            else:
                cur_labels = [cur_label] * len(cur_input_ids)
                cur_labels[0] = label2id['B-' + cur_labels[0]]
                for k in range(1, len(cur_labels)):
                    cur_labels[k] = label2id['I-' + cur_labels[k]]
            assert len(cur_input_ids) == len(
                [bboxes_norm[j]] * len(cur_input_ids)) == len(cur_labels)
            cur_doc_input_ids += cur_input_ids
            cur_doc_bboxes += [bboxes_norm[j]] * len(cur_input_ids)
            cur_doc_labels += cur_labels
        assert len(cur_doc_input_ids) == len(cur_doc_bboxes) == len(
            cur_doc_labels)
        assert len(cur_doc_input_ids) > 0

        data_info = {}
        data_info['img_path'] = img_path
        data_info['input_ids'] = cur_doc_input_ids
        data_info['bboxes'] = cur_doc_bboxes
        data_info['labels'] = cur_doc_labels
        return data_info

    def box_norm(self, box, width, height):

        def clip(min_num, num, max_num):
            return min(max(num, min_num), max_num)

        x0, y0, x1, y1 = box
        x0 = clip(0, int((x0 / width) * 1000), 1000)
        y0 = clip(0, int((y0 / height) * 1000), 1000)
        x1 = clip(0, int((x1 / width) * 1000), 1000)
        y1 = clip(0, int((y1 / height) * 1000), 1000)
        assert x1 >= x0
        assert y1 >= y0
        return [x0, y0, x1, y1]

    def get_segment_ids(self, bboxs):
        segment_ids = []
        for i in range(len(bboxs)):
            if i == 0:
                segment_ids.append(0)
            else:
                if bboxs[i - 1] == bboxs[i]:
                    segment_ids.append(segment_ids[-1])
                else:
                    segment_ids.append(segment_ids[-1] + 1)
        return segment_ids

    def get_position_ids(self, segment_ids):
        position_ids = []
        for i in range(len(segment_ids)):
            if i == 0:
                position_ids.append(2)
            else:
                if segment_ids[i] == segment_ids[i - 1]:
                    position_ids.append(position_ids[-1] + 1)
                else:
                    position_ids.append(2)
        return position_ids
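A quick sanity check of the segment/position id helpers above, run on a made-up token/box sequence: tokens that share a box share a segment id, and position ids count up from 2 within each segment and reset when the segment changes.

bboxes = [[0, 0, 0, 0],                          # CLS placeholder box
          [10, 20, 110, 40], [10, 20, 110, 40],  # two tokens from the same word
          [200, 20, 260, 40],                    # next word
          [1000, 1000, 1000, 1000]]              # SEP placeholder box

segment_ids = []
for i, box in enumerate(bboxes):
    if i == 0:
        segment_ids.append(0)
    elif box == bboxes[i - 1]:
        segment_ids.append(segment_ids[-1])
    else:
        segment_ids.append(segment_ids[-1] + 1)

position_ids = []
for i, seg in enumerate(segment_ids):
    if i == 0 or seg != segment_ids[i - 1]:
        position_ids.append(2)
    else:
        position_ids.append(position_ids[-1] + 1)

print(segment_ids)   # [0, 1, 1, 2, 3]
print(position_ids)  # [2, 2, 3, 2, 2]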

projects/LayoutLMv3/README.md

Whitespace-only changes.
(new config file; file name not preserved in this extract)

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
_base_ = [
    '/Users/wangnu/Documents/GitHub/mmocr/'
    'configs/ser/_base_/datasets/xfund_zh.py'
]

train_dataset = _base_.xfund_zh_ser_train
train_dataloader = dict(
    batch_size=1,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=train_dataset)
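A config like this is normally consumed through MMEngine. A minimal sketch of building the dataloader follows, assuming the file is saved under projects/LayoutLMv3/; the path below is a hypothetical placeholder.

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('projects/LayoutLMv3/configs/ser_xfund_zh.py')  # hypothetical path
train_loader = Runner.build_dataloader(cfg.train_dataloader)
batch = next(iter(train_loader))  # inspect one batch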
