Skip to content

Commit 7e9f775

Browse files
[Feature] Add svtr decoder (#1448)
* add svtr decoder * svtr decoder * update Co-authored-by: gaotongxiao <[email protected]>
1 parent 53e72e4 commit 7e9f775

File tree

3 files changed

+192
-1
lines changed

3 files changed

+192
-1
lines changed

mmocr/models/textrecog/decoders/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder
1313
from .sar_decoder_with_bs import ParallelSARDecoderWithBS
1414
from .sequence_attention_decoder import SequenceAttentionDecoder
15+
from .svtr_decoder import SVTRDecoder
1516

1617
__all__ = [
1718
'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder',
1819
'ParallelSARDecoderWithBS', 'NRTRDecoder', 'BaseDecoder',
1920
'SequenceAttentionDecoder', 'PositionAttentionDecoder',
2021
'ABILanguageDecoder', 'ABIVisionDecoder', 'MasterDecoder',
21-
'RobustScannerFuser', 'ABIFuser', 'ASTERDecoder'
22+
'RobustScannerFuser', 'ABIFuser', 'SVTRDecoder', 'ASTERDecoder'
2223
]
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Dict, List, Optional, Sequence, Union
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
from mmocr.models.common.dictionary import Dictionary
8+
from mmocr.registry import MODELS
9+
from mmocr.structures import TextRecogDataSample
10+
from .base import BaseDecoder
11+
12+
13+
@MODELS.register_module()
class SVTRDecoder(BaseDecoder):
    """Decoder module in `SVTR <https://arxiv.org/abs/2205.00159>`_.

    The decoder is a single linear layer that maps every column of a
    height-1 feature map to a score for each class of the dictionary.

    Args:
        in_channels (int): The num of input channels.
        dictionary (Union[Dict, Dictionary]): The config for `Dictionary` or
            the instance of `Dictionary`. Defaults to None.
        module_loss (Optional[Dict], optional): Cfg to build module_loss.
            Defaults to None.
        postprocessor (Optional[Dict], optional): Cfg to build postprocessor.
            Defaults to None.
        max_seq_len (int, optional): Maximum output sequence length :math:`T`.
            It is forwarded to the base decoder. Defaults to 25.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 dictionary: Union[Dict, Dictionary] = None,
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 max_seq_len: int = 25,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:

        base_kwargs = dict(
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len,
            init_cfg=init_cfg)
        super().__init__(**base_kwargs)

        # ``self.dictionary`` is only available once the base class has been
        # initialised, so the classifier has to be created afterwards.
        num_classes = self.dictionary.num_classes
        self.decoder = nn.Linear(
            in_features=in_channels, out_features=num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for training.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, 1, W)`; its height must be 1. Although it
                defaults to None, this method reads it unconditionally.
            out_enc (torch.Tensor, optional): Encoder output. Not used by
                this decoder. Defaults to None.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Not used
                by this decoder. Defaults to None.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, W, C)` where
            :math:`W` is the width of ``feat`` and :math:`C` is
            ``num_classes``.
        """
        assert feat.size(2) == 1, 'feature height must be 1'
        # (N, E, 1, W) -> (N, E, W) -> (N, W, E), so that the linear layer
        # classifies each of the W positions over the channel dimension.
        sequence = feat.squeeze(2).permute(0, 2, 1)
        return self.decoder(sequence)

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, 1, W)`; its height must be 1. Although it
                defaults to None, this method reads it unconditionally.
            out_enc (torch.Tensor, optional): Encoder output. Not used by
                this decoder. Defaults to None.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Not used
                by this decoder. Defaults to None.

        Returns:
            Tensor: Character probabilities of shape :math:`(N, W, C)` where
            :math:`W` is the width of ``feat`` and :math:`C` is
            ``num_classes``. They are the training logits normalised with a
            softmax over the last dimension.
        """
        logits = self.forward_train(feat, out_enc, data_samples)
        return self.softmax(logits)
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os.path as osp
3+
import tempfile
4+
from unittest import TestCase
5+
6+
import torch
7+
from mmengine.structures import LabelData
8+
9+
from mmocr.models.textrecog.decoders.svtr_decoder import SVTRDecoder
10+
from mmocr.structures import TextRecogDataSample
11+
from mmocr.testing import create_dummy_dict_file
12+
13+
14+
class TestSVTRDecoder(TestCase):
    """Unit tests for :class:`SVTRDecoder`."""

    def setUp(self):
        """Create the two labelled data samples shared by the tests."""
        gt_text_sample1 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'Hello'
        gt_text_sample1.gt_text = gt_text
        gt_text_sample1.set_metainfo(dict(valid_ratio=0.9))

        gt_text_sample2 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'World'
        gt_text_sample2.gt_text = gt_text
        gt_text_sample2.set_metainfo(dict(valid_ratio=1.0))

        self.data_info = [gt_text_sample1, gt_text_sample2]

    def _build_decoder(self, tmp_dir, **kwargs):
        """Build an ``SVTRDecoder`` backed by a dummy dictionary file.

        Args:
            tmp_dir (str): Existing directory in which the dummy dictionary
                file is created.
            **kwargs: Extra keyword arguments forwarded to ``SVTRDecoder``.

        Returns:
            SVTRDecoder: Decoder with 192 input channels and a CTC module
            loss.
        """
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        create_dummy_dict_file(dict_file)
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)
        loss_cfg = dict(type='CTCModuleLoss', letter_case='lower')
        return SVTRDecoder(
            in_channels=192,
            dictionary=dict_cfg,
            module_loss=loss_cfg,
            **kwargs)

    def test_init(self):
        """The decoder can be built from dictionary and loss configs."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            self._build_decoder(tmp_dir)

    def test_forward_train(self):
        """``forward_train`` returns one score vector per feature column."""
        feat = torch.randn(1, 192, 1, 25)
        max_seq_len = 25
        # The context manager removes the temporary directory even when an
        # assertion fails; a bare TemporaryDirectory() would be left behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            decoder = self._build_decoder(tmp_dir, max_seq_len=max_seq_len)
            data_samples = decoder.module_loss.get_targets(self.data_info)
            output = decoder.forward_train(
                feat=feat, data_samples=data_samples)
            self.assertTupleEqual(
                tuple(output.shape), (1, max_seq_len, 39))

    def test_forward_test(self):
        """``forward_test`` returns an output of the same shape."""
        feat = torch.randn(1, 192, 1, 25)
        with tempfile.TemporaryDirectory() as tmp_dir:
            # test dictionary cfg
            decoder = self._build_decoder(tmp_dir, max_seq_len=25)
            output = decoder.forward_test(
                feat=feat, data_samples=self.data_info)
            self.assertTupleEqual(tuple(output.shape), (1, 25, 39))

0 commit comments

Comments
 (0)