Skip to content

Commit 65e746e

Browse files
authored
[UT] Add missing unit tests (#1651)
* update * remove code
1 parent 7e9f775 commit 65e746e

File tree

5 files changed

+140
-30
lines changed

5 files changed

+140
-30
lines changed

mmocr/models/textdet/necks/fpem_ffm.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Dict, List, Optional, Tuple, Union
3+
4+
import torch
25
import torch.nn.functional as F
36
from mmengine.model import BaseModule, ModuleList
47
from torch import nn
@@ -14,7 +17,9 @@ class FPEM(BaseModule):
1417
init_cfg (dict or list[dict], optional): Initialization configs.
1518
"""
1619

17-
def __init__(self, in_channels=128, init_cfg=None):
20+
def __init__(self,
21+
in_channels: int = 128,
22+
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
1823
super().__init__(init_cfg=init_cfg)
1924
self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
2025
self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
@@ -23,7 +28,8 @@ def __init__(self, in_channels=128, init_cfg=None):
2328
self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
2429
self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
2530

26-
def forward(self, c2, c3, c4, c5):
31+
def forward(self, c2: torch.Tensor, c3: torch.Tensor, c4: torch.Tensor,
32+
c5: torch.Tensor) -> List[torch.Tensor]:
2733
"""
2834
Args:
2935
c2, c3, c4, c5 (Tensor): Each has the shape of
@@ -48,8 +54,21 @@ def _upsample_add(self, x, y):
4854

4955

5056
class SeparableConv2d(BaseModule):
57+
"""Implementation of separable convolution, which is consisted of depthwise
58+
convolution and pointwise convolution.
59+
60+
Args:
61+
in_channels (int): Number of input channels.
62+
out_channels (int): Number of output channels.
63+
stride (int): Stride of the depthwise convolution.
64+
init_cfg (dict or list[dict], optional): Initialization configs.
65+
"""
5166

52-
def __init__(self, in_channels, out_channels, stride=1, init_cfg=None):
67+
def __init__(self,
68+
in_channels: int,
69+
out_channels: int,
70+
stride: int = 1,
71+
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
5372
super().__init__(init_cfg=init_cfg)
5473

5574
self.depthwise_conv = nn.Conv2d(
@@ -64,7 +83,15 @@ def __init__(self, in_channels, out_channels, stride=1, init_cfg=None):
6483
self.bn = nn.BatchNorm2d(out_channels)
6584
self.relu = nn.ReLU()
6685

67-
def forward(self, x):
86+
def forward(self, x: torch.Tensor) -> torch.Tensor:
87+
"""Forward function.
88+
89+
Args:
90+
x (Tensor): Input tensor.
91+
92+
Returns:
93+
Tensor: Output tensor.
94+
"""
6895
x = self.depthwise_conv(x)
6996
x = self.pointwise_conv(x)
7097
x = self.bn(x)
@@ -85,13 +112,15 @@ class FPEM_FFM(BaseModule):
85112
init_cfg (dict or list[dict], optional): Initialization configs.
86113
"""
87114

88-
def __init__(self,
89-
in_channels,
90-
conv_out=128,
91-
fpem_repeat=2,
92-
align_corners=False,
93-
init_cfg=dict(
94-
type='Xavier', layer='Conv2d', distribution='uniform')):
115+
def __init__(
116+
self,
117+
in_channels: List[int],
118+
conv_out: int = 128,
119+
fpem_repeat: int = 2,
120+
align_corners: bool = False,
121+
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
122+
type='Xavier', layer='Conv2d', distribution='uniform')
123+
) -> None:
95124
super().__init__(init_cfg=init_cfg)
96125
# reduce layers
97126
self.reduce_conv_c2 = nn.Sequential(
@@ -119,7 +148,7 @@ def __init__(self,
119148
for _ in range(fpem_repeat):
120149
self.fpems.append(FPEM(conv_out))
121150

122-
def forward(self, x):
151+
def forward(self, x: List[torch.Tensor]) -> Tuple[torch.Tensor]:
123152
"""
124153
Args:
125154
x (list[Tensor]): A list of four tensors of shape
@@ -128,7 +157,7 @@ def forward(self, x):
128157
``in_channels``.
129158
130159
Returns:
131-
list[Tensor]: Four tensors of shape
160+
tuple[Tensor]: Four tensors of shape
132161
:math:`(N, C_{out}, H_0, W_0)` where :math:`C_{out}` is
133162
``conv_out``.
134163
"""

mmocr/utils/__init__.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .bbox_utils import (bbox2poly, bbox_center_distance, bbox_diag_distance,
3-
bezier2polygon, is_on_same_line, rescale_bboxes,
4-
stitch_boxes_into_lines)
3+
bezier2polygon, is_on_same_line, rescale_bbox,
4+
rescale_bboxes, stitch_boxes_into_lines)
55
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type,
66
is_type_list, valid_boundary)
77
from .collect_env import collect_env
@@ -34,17 +34,18 @@
3434
'is_2dlist', 'valid_boundary', 'list_to_file', 'list_from_file',
3535
'is_on_same_line', 'stitch_boxes_into_lines', 'StringStripper',
3636
'bezier2polygon', 'sort_points', 'dump_ocr_data', 'recog_anno_to_imginfo',
37-
'rescale_polygons', 'rescale_polygon', 'rescale_bboxes', 'bbox2poly',
38-
'crop_polygon', 'is_poly_inside_rect', 'poly2bbox', 'poly_intersection',
39-
'poly_iou', 'poly_make_valid', 'poly_union', 'poly2shapely',
40-
'polys2shapely', 'register_all_modules', 'offset_polygon', 'sort_vertex8',
41-
'sort_vertex', 'bbox_center_distance', 'bbox_diag_distance',
42-
'boundary_iou', 'point_distance', 'points_center', 'fill_hole',
43-
'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img',
44-
'ConfigType', 'DetSampleList', 'RecForwardResults', 'InitConfigType',
45-
'OptConfigType', 'OptDetSampleList', 'OptInitConfigType', 'OptMultiConfig',
46-
'OptRecSampleList', 'RecSampleList', 'MultiConfig', 'OptTensor',
47-
'ColorType', 'OptKIESampleList', 'KIESampleList', 'is_archive',
48-
'check_integrity', 'list_files', 'get_md5', 'InstanceList', 'LabelList',
49-
'OptInstanceList', 'OptLabelList', 'RangeType', 'remove_pipeline_elements'
37+
'rescale_polygons', 'rescale_polygon', 'rescale_bbox', 'rescale_bboxes',
38+
'bbox2poly', 'crop_polygon', 'is_poly_inside_rect', 'poly2bbox',
39+
'poly_intersection', 'poly_iou', 'poly_make_valid', 'poly_union',
40+
'poly2shapely', 'polys2shapely', 'register_all_modules', 'offset_polygon',
41+
'sort_vertex8', 'sort_vertex', 'bbox_center_distance',
42+
'bbox_diag_distance', 'boundary_iou', 'point_distance', 'points_center',
43+
'fill_hole', 'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img',
44+
'warp_img', 'ConfigType', 'DetSampleList', 'RecForwardResults',
45+
'InitConfigType', 'OptConfigType', 'OptDetSampleList', 'OptInitConfigType',
46+
'OptMultiConfig', 'OptRecSampleList', 'RecSampleList', 'MultiConfig',
47+
'OptTensor', 'ColorType', 'OptKIESampleList', 'KIESampleList',
48+
'is_archive', 'check_integrity', 'list_files', 'get_md5', 'InstanceList',
49+
'LabelList', 'OptInstanceList', 'OptLabelList', 'RangeType',
50+
'remove_pipeline_elements'
5051
]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import unittest
3+
4+
import torch
5+
6+
from mmocr.models.textdet.necks.fpem_ffm import FPEM, FPEM_FFM
7+
8+
9+
class TestFPEM(unittest.TestCase):
10+
11+
def setUp(self):
12+
self.c2 = torch.Tensor(1, 8, 64, 64)
13+
self.c3 = torch.Tensor(1, 8, 32, 32)
14+
self.c4 = torch.Tensor(1, 8, 16, 16)
15+
self.c5 = torch.Tensor(1, 8, 8, 8)
16+
self.fpem = FPEM(in_channels=8)
17+
18+
def test_forward(self):
19+
neck = FPEM(in_channels=8)
20+
neck.init_weights()
21+
out = neck(self.c2, self.c3, self.c4, self.c5)
22+
self.assertTrue(out[0].shape == self.c2.shape)
23+
self.assertTrue(out[1].shape == self.c3.shape)
24+
self.assertTrue(out[2].shape == self.c4.shape)
25+
self.assertTrue(out[3].shape == self.c5.shape)
26+
27+
28+
class TestFPEM_FFM(unittest.TestCase):
29+
30+
def setUp(self):
31+
self.c2 = torch.Tensor(1, 8, 64, 64)
32+
self.c3 = torch.Tensor(1, 16, 32, 32)
33+
self.c4 = torch.Tensor(1, 32, 16, 16)
34+
self.c5 = torch.Tensor(1, 64, 8, 8)
35+
self.in_channels = [8, 16, 32, 64]
36+
self.conv_out = 8
37+
self.features = [self.c2, self.c3, self.c4, self.c5]
38+
39+
def test_forward(self):
40+
neck = FPEM_FFM(in_channels=self.in_channels, conv_out=self.conv_out)
41+
neck.init_weights()
42+
out = neck(self.features)
43+
self.assertTrue(out[0].shape == torch.Size([1, 8, 64, 64]))
44+
self.assertTrue(out[1].shape == out[0].shape)
45+
self.assertTrue(out[2].shape == out[0].shape)
46+
self.assertTrue(out[3].shape == out[0].shape)

tests/test_utils/test_bbox_utils.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import torch
66

77
from mmocr.utils import (bbox2poly, bbox_center_distance, bbox_diag_distance,
8-
bezier2polygon, is_on_same_line,
9-
stitch_boxes_into_lines)
8+
bezier2polygon, is_on_same_line, rescale_bbox,
9+
rescale_bboxes, stitch_boxes_into_lines)
1010
from mmocr.utils.bbox_utils import bbox_jitter
1111

1212

@@ -236,3 +236,31 @@ def test_stitch_boxes_into_lines(self):
236236
result.sort(key=lambda x: x['box'][0])
237237
expected_result.sort(key=lambda x: x['box'][0])
238238
self.assertEqual(result, expected_result)
239+
240+
241+
class TestRescaleBbox(unittest.TestCase):
242+
243+
def setUp(self) -> None:
244+
self.bbox = np.array([0, 0, 1, 1])
245+
self.bboxes = np.array([[0, 0, 1, 1], [1, 1, 2, 2]])
246+
self.scale = 2
247+
248+
def test_rescale_bbox(self):
249+
# mul
250+
rescaled_bbox = rescale_bbox(self.bbox, self.scale, mode='mul')
251+
self.assertTrue(np.allclose(rescaled_bbox, np.array([0, 0, 2, 2])))
252+
# div
253+
rescaled_bbox = rescale_bbox(self.bbox, self.scale, mode='div')
254+
self.assertTrue(np.allclose(rescaled_bbox, np.array([0, 0, 0.5, 0.5])))
255+
256+
def test_rescale_bboxes(self):
257+
# mul
258+
rescaled_bboxes = rescale_bboxes(self.bboxes, self.scale, mode='mul')
259+
self.assertTrue(
260+
np.allclose(rescaled_bboxes, np.array([[0, 0, 2, 2], [2, 2, 4,
261+
4]])))
262+
# div
263+
rescaled_bboxes = rescale_bboxes(self.bboxes, self.scale, mode='div')
264+
self.assertTrue(
265+
np.allclose(rescaled_bboxes,
266+
np.array([[0, 0, 0.5, 0.5], [0.5, 0.5, 1, 1]])))

tests/test_utils/test_check_argument.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@ def test_valid_boundary():
4646
assert utils.valid_boundary(x, False)
4747
x = [0, 0, 1, 0, 1, 1, 0, 1, 1]
4848
assert utils.valid_boundary(x, True)
49+
50+
51+
def test_equal_len():
52+
53+
assert utils.equal_len([1, 2, 3], [1, 2, 3])
54+
assert not utils.equal_len([1, 2, 3], [1, 2, 3, 4])

0 commit comments

Comments
 (0)