Skip to content

Commit 75c06d3

Browse files
MountchickenHarold-lkkgaotongxiao
authored
[Dataset Preparer] Add SCUT-CTW1500 (#1677)
* update metafile and download * update parser * update ctw1500 to new dataprepare design * add lexicon into ctw1500 textspotting * fix --------- Co-authored-by: liukuikun <[email protected]> Co-authored-by: gaotongxiao <[email protected]>
1 parent bfb36d8 commit 75c06d3

File tree

7 files changed

+312
-1
lines changed

7 files changed

+312
-1
lines changed

dataset_zoo/ctw1500/metafile.yml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Dataset metadata for SCUT-CTW1500 (paper, source website, supported tasks).
Name: 'CTW1500'
Paper:
  Title: Curved scene text detection via transverse and longitudinal sequence connection
  URL: https://www.sciencedirect.com/science/article/pii/S0031320319300664
  Venue: PR
  Year: '2019'
  BibTeX: '@article{liu2019curved,
    title={Curved scene text detection via transverse and longitudinal sequence connection},
    author={Liu, Yuliang and Jin, Lianwen and Zhang, Shuaitao and Luo, Canjie and Zhang, Sheng},
    journal={Pattern Recognition},
    volume={90},
    pages={337--345},
    year={2019},
    publisher={Elsevier}
    }'
Data:
  Website: https://github.com/Yuliang-Liu/Curve-Text-Detector
  Language:
    - English
  Scene:
    - Scene
  Granularity:
    - Word
    - Line
  # Each task listed here has a matching config file in this directory.
  Tasks:
    - textrecog
    - textdet
    - textspotting
  License:
    Type: N/A
    Link: N/A
  # NOTE(review): the test split in textdet.py is paired with .txt
  # annotation files, while only .xml is listed here - confirm whether both
  # formats should be mentioned.
  Format: .xml

dataset_zoo/ctw1500/textdet.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Dataset Preparer config for the SCUT-CTW1500 text detection task.
# textrecog.py and the text-spotting config inherit from this file via
# ``_base_`` and override individual fields, so key names must stay stable.
data_root = 'data/ctw1500'
cache_path = 'data/cache'

# Training split: images and XML annotations come from two separate archives.
train_preparer = dict(
    obtainer=dict(
        type='NaiveDataObtainer',
        cache_path=cache_path,
        files=[
            dict(
                url='https://universityofadelaide.box.com/shared/static/'
                'py5uwlfyyytbb2pxzq9czvu6fuqbjdh8.zip',
                save_name='ctw1500_train_images.zip',
                md5='f1453464b764343040644464d5c0c4fa',
                split=['train'],
                content=['image'],
                # presumably [path inside the extracted archive, destination
                # relative to data_root] - confirm against NaiveDataObtainer
                mapping=[[
                    'ctw1500_train_images/train_images', 'textdet_imgs/train'
                ]]),
            dict(
                url='https://universityofadelaide.box.com/shared/static/'
                'jikuazluzyj4lq6umzei7m2ppmt3afyw.zip',
                save_name='ctw1500_train_labels.zip',
                md5='d9ba721b25be95c2d78aeb54f812a5b1',
                split=['train'],
                content=['annotation'],
                mapping=[[
                    'ctw1500_train_labels/ctw1500_train_labels/',
                    'annotations/train'
                ]])
        ]),
    gatherer=dict(
        type='PairGatherer',
        img_suffixes=['.jpg', '.JPG'],
        # Regex pair: a four-digit image name such as 0200.jpg is paired with
        # the annotation file 0200.xml.
        # NOTE(review): the pattern only mentions the lowercase '.jpg' suffix
        # and its dot is unescaped - confirm how PairGatherer applies it.
        rule=[r'(\d{4}).jpg', r'\1.xml']),
    parser=dict(type='CTW1500AnnParser'),
    packer=dict(type='TextDetPacker'),
    dumper=dict(type='JsonDumper'),
)

# Test split: same structure as above, but the annotations are .txt files
# hosted on a different server.
test_preparer = dict(
    obtainer=dict(
        type='NaiveDataObtainer',
        cache_path=cache_path,
        files=[
            dict(
                url='https://universityofadelaide.box.com/shared/static/'
                't4w48ofnqkdw7jyc4t11nsukoeqk9c3d.zip',
                save_name='ctw1500_test_images.zip',
                md5='79103fd77dfdd2c70ae6feb3a2fb4530',
                split=['test'],
                content=['image'],
                mapping=[[
                    'ctw1500_test_images/test_images', 'textdet_imgs/test'
                ]]),
            dict(
                url='https://cloudstor.aarnet.edu.au/plus/s/uoeFl0pCN9BOCN5/'
                'download',
                save_name='ctw1500_test_labels.zip',
                md5='7f650933a30cf1bcdbb7874e4962a52b',
                split=['test'],
                content=['annotation'],
                mapping=[['ctw1500_test_labels', 'annotations/test']])
        ]),
    gatherer=dict(
        type='PairGatherer',
        img_suffixes=['.jpg', '.JPG'],
        # Test annotation names carry an extra '000' prefix: 1001.jpg is
        # paired with 0001001.txt.
        rule=[r'(\d{4}).jpg', r'000\1.txt']),
    parser=dict(type='CTW1500AnnParser'),
    packer=dict(type='TextDetPacker'),
    dumper=dict(type='JsonDumper'),
)
# presumably paths (relative to data_root) cleaned up once preparation is
# done - confirm against the preparer implementation
delete = [
    'ctw1500_train_images', 'ctw1500_test_images', 'annotations',
    'ctw1500_train_labels', 'ctw1500_test_labels'
]
config_generator = dict(type='TextDetConfigGenerator')

dataset_zoo/ctw1500/textrecog.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Dataset Preparer config for the SCUT-CTW1500 text recognition task.
# Everything is inherited from textdet.py; only the fields below differ.
_base_ = ['textdet.py']

# Point both gatherers at the image directories that the textdet.py mappings
# populate.
_base_.train_preparer.gatherer.img_dir = 'textdet_imgs/train'
_base_.test_preparer.gatherer.img_dir = 'textdet_imgs/test'

# Swap the detection packer for the recognition crop packer on both splits.
_base_.train_preparer.packer.type = 'TextRecogCropPacker'
_base_.test_preparer.packer.type = 'TextRecogCropPacker'

config_generator = dict(type='TextRecogConfigGenerator')
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Dataset Preparer config for the SCUT-CTW1500 text spotting task.
# Everything is inherited from textdet.py; only the fields below differ.
_base_ = ['textdet.py']

# Point both gatherers at the image directories that the textdet.py mappings
# populate.
_base_.train_preparer.gatherer.img_dir = 'textdet_imgs/train'
_base_.test_preparer.gatherer.img_dir = 'textdet_imgs/test'

# Swap the detection packer for the text spotting packer on both splits.
_base_.train_preparer.packer.type = 'TextSpottingPacker'
_base_.test_preparer.packer.type = 'TextSpottingPacker'

# Text spotting additionally downloads a lexicon archive; it is appended to
# the test split's file list only.
_base_.test_preparer.obtainer.files.append(
    dict(
        url='https://download.openmmlab.com/mmocr/data/1.x/textspotting/'
        'ctw1500/lexicons.zip',
        save_name='ctw1500_lexicons.zip',
        md5='168150ca45da161917bf35a20e45b8d6',
        content=['lexicons'],
        mapping=[['ctw1500_lexicons/lexicons', 'lexicons']]))

# Add the lexicon entry to the inherited ``delete`` list as well.
_base_.delete.append('ctw1500_lexicons')
config_generator = dict(type='TextSpottingConfigGenerator')

mmocr/datasets/preparers/parsers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .base import BaseParser
33
from .coco_parser import COCOTextDetAnnParser
4+
from .ctw1500_parser import CTW1500AnnParser
45
from .funsd_parser import FUNSDTextDetAnnParser
56
from .icdar_txt_parser import (ICDARTxtTextDetAnnParser,
67
ICDARTxtTextRecogAnnParser)
@@ -14,5 +15,5 @@
1415
'BaseParser', 'ICDARTxtTextDetAnnParser', 'ICDARTxtTextRecogAnnParser',
1516
'TotaltextTextDetAnnParser', 'WildreceiptKIEAnnParser',
1617
'COCOTextDetAnnParser', 'SVTTextDetAnnParser', 'FUNSDTextDetAnnParser',
17-
'SROIETextDetAnnParser', 'NAFAnnParser'
18+
'SROIETextDetAnnParser', 'NAFAnnParser', 'CTW1500AnnParser'
1819
]
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import xml.etree.ElementTree as ET
3+
from typing import List, Tuple
4+
5+
import numpy as np
6+
7+
from mmocr.datasets.preparers.data_preparer import DATA_PARSERS
8+
from mmocr.datasets.preparers.parsers.base import BaseParser
9+
from mmocr.utils import list_from_file
10+
11+
12+
@DATA_PARSERS.register_module()
class CTW1500AnnParser(BaseParser):
    """SCUT-CTW1500 dataset parser.

    Annotations of the train split are read from XML files and annotations
    of the test split from plain-text files. In both cases every instance is
    a polygon stored as a flat list of 28 integers
    (``[x1, y1, ..., x14, y14]``) together with its transcription.

    Args:
        ignore (str): The text of the ignored instances. Defaults to
            '###'.
    """

    def __init__(self, ignore: str = '###', **kwargs) -> None:
        self.ignore = ignore
        # Remaining keyword arguments (e.g. ``split``) are handled by the
        # base class.
        super().__init__(**kwargs)

    def parse_file(self, img_path: str, ann_path: str) -> Tuple:
        """Convert annotation for a single image.

        Args:
            img_path (str): The path of image.
            ann_path (str): The path of annotation.

        Returns:
            Tuple: A tuple of (img_path, instances).

            - img_path (str): The path of image file, returned unchanged.
            - instances (list[dict]): One dict per annotated instance with
              the following keys:

              - 'poly': flat list of 28 ints
              - 'text': transcription of the instance
              - 'ignore': whether the instance should be ignored

        Raises:
            ValueError: If ``self.split`` is neither 'train' nor 'test'.

        Examples:
            An example of returned values:
            >>> ('imgs/train/xxx.jpg',
            >>>  [dict(
            >>>      poly=[131, 58, 208, 49, ..., 165, 134],
            >>>      text='hello',
            >>>      ignore=False)]
            >>> )
        """
        if self.split == 'train':
            instances = self.load_xml_info(ann_path)
        elif self.split == 'test':
            instances = self.load_txt_info(ann_path)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # the return statement below.
            raise ValueError(
                f"Unsupported split '{self.split}': CTW1500AnnParser only "
                "handles 'train' and 'test'.")
        return img_path, instances

    def load_txt_info(self, anno_dir: str) -> List:
        """Load the annotation of the SCUT-CTW dataset (test split).

        Args:
            anno_dir (str): Path to the annotation file.

        Returns:
            list[Dict]: List of instances.
        """
        instances = list()
        for line in list_from_file(anno_dir):
            # Each line has one polygon (14 vertices, i.e. 28 coordinates)
            # followed by one text field prefixed with '####', e.g.
            # 48,84,61,79,...,45,73,####E-313
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing on strs[28].
                continue
            # Split at most 28 times so that commas inside the transcription
            # stay in the last field instead of truncating the text.
            strs = line.split(',', 28)
            assert strs[28][0] == '#'
            xy = [int(x) for x in strs[0:28]]
            assert len(xy) == 28
            poly = np.array(xy).reshape(-1).tolist()
            # Drop the leading '####' marker to get the transcription.
            text = strs[28][4:]
            instances.append(
                dict(poly=poly, text=text, ignore=text == self.ignore))
        return instances

    def load_xml_info(self, anno_dir: str) -> List:
        """Load the annotation of the SCUT-CTW dataset (train split).

        Args:
            anno_dir (str): Path to the annotation file.

        Returns:
            list[Dict]: List of instances.
        """
        obj = ET.parse(anno_dir)
        instances = list()
        for image in obj.getroot():  # image
            for box in image:  # box
                # The first child of a box holds the transcription and the
                # second child the comma-separated polygon coordinates.
                text = box[0].text
                segs = box[1].text
                pts = segs.strip().split(',')
                pts = [int(x) for x in pts]
                assert len(pts) == 28
                poly = np.array(pts).reshape(-1).tolist()
                # Training instances are never marked as ignored; use a bool
                # so the value type matches the test split.
                instances.append(dict(poly=poly, text=text, ignore=False))
        return instances
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os.path as osp
3+
import tempfile
4+
import unittest
5+
6+
from mmocr.datasets.preparers.parsers import CTW1500AnnParser
7+
from mmocr.utils import list_to_file
8+
9+
10+
class TestCTW1500AnnParser(unittest.TestCase):
    """Checks ``CTW1500AnnParser`` against small synthetic annotations."""

    def setUp(self) -> None:
        # Every test gets its own scratch directory.
        self.root = tempfile.TemporaryDirectory()

    def _create_dummy_ctw1500_det(self):
        """Write one fake train (XML) and one fake test (TXT) annotation.

        Returns:
            tuple: (image path, train annotation path, test annotation
            path). The image path is only a name inside the temporary
            directory; no image file is written.
        """
        tmp_dir = self.root.name

        xml_path = osp.join(tmp_dir, 'ctw1500_train.xml')
        xml_lines = [
            '<Annotations>',
            ' <image file="0200.jpg">',
            ' <box height="197" left="131" top="49" width="399">',
            ' <label>OLATHE</label>',
            ' <segs>131,58,208,49,279,56,346,76,412,101,473,141,530,192,510,246,458,210,405,175,350,151,291,137,228,133,165,134</segs>',  # noqa: E501
            ' <pts x="183" y="95" />',
            ' <pts x="251" y="89" />',
            ' <pts x="322" y="107" />',
            ' <pts x="383" y="124" />',
            ' <pts x="441" y="161" />',
            ' <pts x="493" y="201" />',
            ' </box>',
            ' </image>',
            '</Annotations>',
        ]
        list_to_file(xml_path, xml_lines)

        txt_path = osp.join(tmp_dir, 'ctw1500_test.txt')
        txt_lines = [
            '48,84,61,79,75,73,88,68,102,74,116,79,130,84,135,73,119,67,104,60,89,56,74,61,59,67,45,73,#######',  # noqa: E501
            '51,137,58,137,66,137,74,137,82,137,90,137,98,137,98,119,90,119,82,119,74,119,66,119,58,119,50,119,####E-313',  # noqa: E501
            '41,155,49,155,57,155,65,155,73,155,81,155,89,155,87,136,79,136,71,136,64,136,56,136,48,136,41,137,#######',  # noqa: E501
            '41,193,57,193,74,194,90,194,107,195,123,195,140,196,146,168,128,167,110,167,92,167,74,166,56,166,39,166,####F.D.N.Y.',  # noqa: E501
        ]
        list_to_file(txt_path, txt_lines)

        return osp.join(tmp_dir, 'ctw1500.jpg'), xml_path, txt_path

    def test_textdet_parsers(self):
        dummy_img, xml_file, txt_file = self._create_dummy_ctw1500_det()
        expected_img = osp.join(self.root.name, 'ctw1500.jpg')

        # Train split: a single XML box is expected.
        train_parser = CTW1500AnnParser(split='train')
        parsed_img, train_instances = train_parser.parse_file(
            dummy_img, xml_file)
        self.assertEqual(parsed_img, expected_img)
        self.assertEqual(len(train_instances), 1)
        first = train_instances[0]
        self.assertEqual(first['text'], 'OLATHE')
        expected_train_poly = [
            131, 58, 208, 49, 279, 56, 346, 76, 412, 101, 473, 141, 530, 192,
            510, 246, 458, 210, 405, 175, 350, 151, 291, 137, 228, 133, 165,
            134
        ]
        self.assertEqual(first['poly'], expected_train_poly)
        self.assertEqual(first['ignore'], False)

        # Test split: four TXT lines, two of which carry the ignore text.
        test_parser = CTW1500AnnParser(split='test')
        parsed_img, test_instances = test_parser.parse_file(
            parsed_img, txt_file)
        self.assertEqual(parsed_img, expected_img)
        self.assertEqual(len(test_instances), 4)
        self.assertEqual(test_instances[0]['ignore'], True)
        self.assertEqual(test_instances[1]['text'], 'E-313')
        expected_test_poly = [
            41, 193, 57, 193, 74, 194, 90, 194, 107, 195, 123, 195, 140, 196,
            146, 168, 128, 167, 110, 167, 92, 167, 74, 166, 56, 166, 39, 166
        ]
        self.assertEqual(test_instances[3]['poly'], expected_test_poly)

    def tearDown(self) -> None:
        # Remove the scratch directory together with the fake annotations.
        self.root.cleanup()

0 commit comments

Comments
 (0)