Skip to content

Commit 4d5ed98

Browse files
authored
[Feature] ABCNet train (#1610)
* abcnet train * fix comment * update link * fix lint * fix name
1 parent 5dbacfe commit 4d5ed98

File tree

9 files changed

+490
-35
lines changed

9 files changed

+490
-35
lines changed

mmocr/models/textdet/heads/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def loss_and_predict(self, x: Tuple[Tensor], data_samples: DetSampleList
108108
outs = self(x, data_samples)
109109
losses = self.module_loss(outs, data_samples)
110110

111-
predictions = self.postprocessor(outs, data_samples)
111+
predictions = self.postprocessor(outs, data_samples, self.training)
112112
return losses, predictions
113113

114114
def predict(self, x: torch.Tensor,
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .abcnet import ABCNet
33
from .abcnet_det_head import ABCNetDetHead
4+
from .abcnet_det_module_loss import ABCNetDetModuleLoss
45
from .abcnet_det_postprocessor import ABCNetDetPostprocessor
56
from .abcnet_postprocessor import ABCNetPostprocessor
67
from .abcnet_rec import ABCNetRec
78
from .abcnet_rec_backbone import ABCNetRecBackbone
89
from .abcnet_rec_decoder import ABCNetRecDecoder
910
from .abcnet_rec_encoder import ABCNetRecEncoder
1011
from .bezier_roi_extractor import BezierRoIExtractor
11-
from .only_rec_roi_head import OnlyRecRoIHead
12+
from .rec_roi_head import RecRoIHead
1213

1314
__all__ = [
1415
'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',
1516
'ABCNetRecDecoder', 'ABCNetRecEncoder', 'ABCNet', 'ABCNetRec',
16-
'BezierRoIExtractor', 'OnlyRecRoIHead', 'ABCNetPostprocessor'
17+
'BezierRoIExtractor', 'RecRoIHead', 'ABCNetPostprocessor',
18+
'ABCNetDetModuleLoss'
1719
]

projects/ABCNet/abcnet/model/abcnet_det_module_loss.py

Lines changed: 359 additions & 0 deletions
Large diffs are not rendered by default.

projects/ABCNet/abcnet/model/only_rec_roi_head.py renamed to projects/ABCNet/abcnet/model/rec_roi_head.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from typing import Tuple
33

4+
from mmengine.structures import LabelData
45
from torch import Tensor
56

67
from mmocr.registry import MODELS, TASK_UTILS
@@ -10,7 +11,7 @@
1011

1112

1213
@MODELS.register_module()
13-
class OnlyRecRoIHead(BaseRoIHead):
14+
class RecRoIHead(BaseRoIHead):
1415
"""Simplest base roi head including one bbox head and one mask head."""
1516

1617
def __init__(self,
@@ -39,8 +40,17 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
3940
Returns:
4041
dict[str, Tensor]: A dictionary of loss components
4142
"""
43+
proposals = [
44+
ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
45+
]
4246

43-
pass
47+
proposals = [p for p in proposals if len(p) > 0]
48+
bbox_feats = self.roi_extractor(inputs, proposals)
49+
rec_data_samples = [
50+
TextRecogDataSample(gt_text=LabelData(item=text))
51+
for proposal in proposals for text in proposal.texts
52+
]
53+
return self.rec_head.loss(bbox_feats, rec_data_samples)
4454

4555
def predict(self, inputs: Tuple[Tensor],
4656
data_samples: DetSampleList) -> RecSampleList:

projects/ABCNet/abcnet/model/two_stage_text_spotting.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,14 @@ def extract_feat(self, img):
7070

7171
def loss(self, inputs: torch.Tensor,
7272
data_samples: OptDetSampleList) -> Dict:
73-
pass
73+
losses = dict()
74+
inputs = self.extract_feat(inputs)
75+
det_loss, data_samples = self.det_head.loss_and_predict(
76+
inputs, data_samples)
77+
roi_losses = self.roi_head.loss(inputs, data_samples)
78+
losses.update(det_loss)
79+
losses.update(roi_losses)
80+
return losses
7481

7582
def predict(self, inputs: torch.Tensor,
7683
data_samples: OptDetSampleList) -> OptDetSampleList:
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# optimizer
2+
optim_wrapper = dict(
3+
type='OptimWrapper',
4+
optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001),
5+
clip_grad=dict(type='value', clip_value=1))
6+
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=20)
7+
val_cfg = dict(type='ValLoop')
8+
test_cfg = dict(type='TestLoop')
9+
# learning policy
10+
param_scheduler = [
11+
dict(type='LinearLR', end=1000, start_factor=0.001, by_epoch=False),
12+
]

projects/ABCNet/config/abcnet/_base_abcnet-det_resnet50_fpn.py renamed to projects/ABCNet/config/abcnet/_base_abcnet_resnet50_fpn.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,37 @@
6767
std=0.01,
6868
bias=-4.59511985013459), # -log((1-p)/p) where p=0.01
6969
),
70-
module_loss=None,
70+
module_loss=dict(
71+
type='ABCNetDetModuleLoss',
72+
num_classes=num_classes,
73+
strides=strides,
74+
center_sampling=True,
75+
center_sample_radius=1.5,
76+
bbox_coder=bbox_coder,
77+
norm_on_bbox=norm_on_bbox,
78+
loss_cls=dict(
79+
type='mmdet.FocalLoss',
80+
use_sigmoid=use_sigmoid_cls,
81+
gamma=2.0,
82+
alpha=0.25,
83+
loss_weight=1.0),
84+
loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=1.0),
85+
loss_centerness=dict(
86+
type='mmdet.CrossEntropyLoss',
87+
use_sigmoid=True,
88+
loss_weight=1.0)),
7189
postprocessor=dict(
7290
type='ABCNetDetPostprocessor',
73-
# rescale_fields=['polygons', 'bboxes'],
7491
use_sigmoid_cls=use_sigmoid_cls,
7592
strides=[8, 16, 32, 64, 128],
7693
bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
7794
with_bezier=True,
7895
test_cfg=dict(
79-
# rescale_fields=['polygon', 'bboxes', 'bezier'],
8096
nms_pre=1000,
8197
nms=dict(type='nms', iou_threshold=0.5),
8298
score_thr=0.3))),
8399
roi_head=dict(
84-
type='OnlyRecRoIHead',
100+
type='RecRoIHead',
85101
roi_extractor=dict(
86102
type='BezierRoIExtractor',
87103
roi_layer=dict(
@@ -95,7 +111,14 @@
95111
decoder=dict(
96112
type='ABCNetRecDecoder',
97113
dictionary=dictionary,
98-
postprocessor=dict(type='AttentionPostprocessor'),
114+
postprocessor=dict(
115+
type='AttentionPostprocessor',
116+
ignore_chars=['padding', 'unknown']),
117+
module_loss=dict(
118+
type='CEModuleLoss',
119+
ignore_first_char=False,
120+
ignore_char=-1,
121+
reduction='mean'),
99122
max_seq_len=25))),
100123
postprocessor=dict(
101124
type='ABCNetPostprocessor',
@@ -118,3 +141,32 @@
118141
type='PackTextDetInputs',
119142
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
120143
]
144+
145+
train_pipeline = [
146+
dict(
147+
type='LoadImageFromFile',
148+
file_client_args=file_client_args,
149+
color_type='color_ignore_orientation'),
150+
dict(
151+
type='LoadOCRAnnotations',
152+
with_polygon=True,
153+
with_bbox=True,
154+
with_label=True,
155+
with_text=True),
156+
dict(type='RemoveIgnored'),
157+
dict(type='RandomCrop', min_side_ratio=0.1),
158+
dict(
159+
type='RandomRotate',
160+
max_angle=30,
161+
pad_with_fixed_color=True,
162+
use_canvas=True),
163+
dict(
164+
type='RandomChoiceResize',
165+
scales=[(980, 2900), (1044, 2900), (1108, 2900), (1172, 2900),
166+
(1236, 2900), (1300, 2900), (1364, 2900), (1428, 2900),
167+
(1492, 2900)],
168+
keep_ratio=True),
169+
dict(
170+
type='PackTextDetInputs',
171+
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
172+
]

projects/ABCNet/config/abcnet/abcnet_resnet50_fpn.py

Lines changed: 0 additions & 24 deletions
This file was deleted.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
_base_ = [
2+
'_base_abcnet_resnet50_fpn.py',
3+
'../_base_/datasets/icdar2015.py',
4+
'../_base_/default_runtime.py',
5+
'../_base_/schedules/schedule_sgd_500e.py',
6+
]
7+
8+
# dataset settings
9+
icdar2015_textspotting_train = _base_.icdar2015_textspotting_train
10+
icdar2015_textspotting_train.pipeline = _base_.train_pipeline
11+
icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
12+
icdar2015_textspotting_test.pipeline = _base_.test_pipeline
13+
14+
train_dataloader = dict(
15+
batch_size=2,
16+
num_workers=8,
17+
persistent_workers=True,
18+
sampler=dict(type='DefaultSampler', shuffle=True),
19+
dataset=icdar2015_textspotting_train)
20+
21+
val_dataloader = dict(
22+
batch_size=1,
23+
num_workers=4,
24+
persistent_workers=True,
25+
sampler=dict(type='DefaultSampler', shuffle=False),
26+
dataset=icdar2015_textspotting_test)
27+
28+
test_dataloader = val_dataloader
29+
30+
val_cfg = dict(type='ValLoop')
31+
test_cfg = dict(type='TestLoop')
32+
33+
custom_imports = dict(imports=['abcnet'], allow_failed_imports=False)
34+
35+
load_from = 'https://download.openmmlab.com/mmocr/textspotting/abcnet/abcnet_resnet50_fpn_500e_icdar2015/abcnet_resnet50_fpn_pretrain-d060636c.pth' # noqa
36+
37+
find_unused_parameters = True

0 commit comments

Comments
 (0)