Skip to content

Commit 3b6fed1

Browse files
committed
abcnetv2 train
1 parent bf41194 commit 3b6fed1

File tree

13 files changed

+188
-24
lines changed

13 files changed

+188
-24
lines changed

projects/ABCNet/abcnet/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .bifpn import BiFPN
1313
from .coordinate_head import CoordinateHead
1414
from .rec_roi_head import RecRoIHead
15+
from .task_utils import * # noqa: F401,F403
1516

1617
__all__ = [
1718
'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',

projects/ABCNet/abcnet/model/abcnet_det_head.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ def forward_single(self, x, scale, stride):
181181
# float to avoid overflow when enabling FP16
182182
if self.use_scale:
183183
bbox_pred = scale(bbox_pred).float()
184-
else:
185-
bbox_pred = bbox_pred.float()
184+
# else:
185+
# bbox_pred = bbox_pred.float()
186186
if self.norm_on_bbox:
187187
# bbox_pred needed for gradient computation has been modified
188188
# by F.relu(bbox_pred) when run with PyTorch 1.10. So replace

projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def __call__(self, pred_results, data_samples, training: bool = False):
216216
Returns:
217217
list[TextDetDataSample]: Batch of post-processed datasamples.
218218
"""
219-
if training:
220-
return data_samples
219+
# if training:
220+
# return data_samples
221221
cfg = self.train_cfg if training else self.test_cfg
222222
if cfg is None:
223223
cfg = {}

projects/ABCNet/abcnet/model/bezier_roi_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def forward(self, feats: Tuple[Tensor],
8888
# convert fp32 to fp16 when amp is on
8989
rois = rois.type_as(feats[0])
9090
out_size = self.roi_layers[0].output_size
91-
feats = feats[:3]
91+
# feats = feats[:3]
9292
num_levels = len(feats)
9393
roi_feats = feats[0].new_zeros(
9494
rois.size(0), self.out_channels, *out_size)

projects/ABCNet/abcnet/model/bifpn.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,12 @@ def __init__(self,
170170
self.bifpn_convs = nn.ModuleList()
171171
# weighted
172172
self.weight_two_nodes = nn.Parameter(
173-
torch.Tensor(2, levels).fill_(init))
173+
torch.Tensor(2, levels).fill_(init), requires_grad=True)
174+
174175
self.weight_three_nodes = nn.Parameter(
175-
torch.Tensor(3, levels - 2).fill_(init))
176-
self.relu = nn.ReLU()
176+
torch.Tensor(3, levels - 2).fill_(init), requires_grad=True)
177+
178+
# self.relu = nn.ReLU(inplace=False)
177179
for _ in range(2):
178180
for _ in range(self.levels - 1): # 1,2,3
179181
fpn_conv = nn.Sequential(
@@ -193,9 +195,10 @@ def forward(self, inputs):
193195
# build top-down and down-top path with stack
194196
levels = self.levels
195197
# w relu
196-
w1 = self.relu(self.weight_two_nodes)
197-
w1 /= torch.sum(w1, dim=0) + self.eps # normalize
198-
w2 = self.relu(self.weight_three_nodes)
198+
199+
_w1 = F.relu(self.weight_two_nodes)
200+
w1 = _w1 / (torch.sum(_w1, dim=0) + self.eps) # normalize
201+
w2 = F.relu(self.weight_three_nodes)
199202
# w2 /= torch.sum(w2, dim=0) + self.eps # normalize
200203
# build top-down
201204
idx_bifpn = 0

projects/ABCNet/abcnet/model/rec_roi_head.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from typing import Tuple
2+
from typing import Optional, Sequence, Tuple
33

44
from mmengine.structures import LabelData
55
from torch import Tensor
@@ -15,14 +15,19 @@ class RecRoIHead(BaseRoIHead):
1515
"""Simplest base roi head including one bbox head and one mask head."""
1616

1717
def __init__(self,
18-
neck=None,
18+
inputs_indices: Optional[Sequence] = None,
19+
neck: OptMultiConfig = None,
20+
assigner: OptMultiConfig = None,
1921
sampler: OptMultiConfig = None,
2022
roi_extractor: OptMultiConfig = None,
2123
rec_head: OptMultiConfig = None,
2224
init_cfg=None):
2325
super().__init__(init_cfg)
24-
if sampler is not None:
25-
self.sampler = TASK_UTILS.build(sampler)
26+
self.inputs_indices = inputs_indices
27+
self.assigner = assigner
28+
if assigner is not None:
29+
self.assigner = TASK_UTILS.build(assigner)
30+
self.sampler = TASK_UTILS.build(sampler)
2631
if neck is not None:
2732
self.neck = MODELS.build(neck)
2833
self.roi_extractor = MODELS.build(roi_extractor)
@@ -43,11 +48,39 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
4348
Returns:
4449
dict[str, Tensor]: A dictionary of loss components
4550
"""
46-
proposals = [
47-
ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
48-
]
51+
52+
if self.inputs_indices is not None:
53+
inputs = [inputs[i] for i in self.inputs_indices]
54+
# proposals = [
55+
# ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
56+
# ]
57+
proposals = list()
58+
for ds in data_samples:
59+
pred_instances = ds.pred_instances
60+
gt_instances = ds.gt_instances
61+
# # assign
62+
# gt_beziers = gt_instances.beziers
63+
# pred_beziers = pred_instances.beziers
64+
# assign_index = [
65+
# int(
66+
# torch.argmin(
67+
# torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
68+
# for i in range(len(pred_beziers))
69+
# ]
70+
# proposal = InstanceData()
71+
# proposal.texts = gt_instances.texts + gt_instances[
72+
# assign_index].texts
73+
# proposal.beziers = torch.cat(
74+
# [gt_instances.beziers, pred_instances.beziers], dim=0)
75+
if self.assigner:
76+
gt_instances, pred_instances = self.assigner.assign(
77+
gt_instances, pred_instances)
78+
proposal = self.sampler.sample(gt_instances, pred_instances)
79+
proposals.append(proposal)
4980

5081
proposals = [p for p in proposals if len(p) > 0]
82+
if hasattr(self, 'neck') and self.neck is not None:
83+
inputs = self.neck(inputs)
5184
bbox_feats = self.roi_extractor(inputs, proposals)
5285
rec_data_samples = [
5386
TextRecogDataSample(gt_text=LabelData(item=text))
@@ -57,6 +90,7 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
5790

5891
def predict(self, inputs: Tuple[Tensor],
5992
data_samples: DetSampleList) -> RecSampleList:
93+
inputs = inputs[:3]
6094
if hasattr(self, 'neck') and self.neck is not None:
6195
inputs = self.neck(inputs)
6296
pred_instances = [ds.pred_instances for ds in data_samples]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .assigner import L1DistanceAssigner
3+
from .sampler import ConcatSampler, OnlyGTSampler
4+
5+
__all__ = ['L1DistanceAssigner', 'ConcatSampler', 'OnlyGTSampler']
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
4+
from mmocr.registry import TASK_UTILS
5+
6+
7+
@TASK_UTILS.register_module()
class L1DistanceAssigner:
    """Match each predicted bezier to the ground-truth bezier that is
    closest in summed absolute (L1) difference."""

    def assign(self, gt_instances, pred_instances):
        """Attach the index of the nearest ground truth to every prediction.

        Args:
            gt_instances: Object exposing a ``beziers`` attribute
                (presumably an ``InstanceData`` holding a 2-D tensor, one
                row per ground-truth instance -- TODO confirm).
            pred_instances: Object exposing a ``beziers`` attribute that can
                be indexed per prediction. It is modified in place: a new
                ``assign_index`` attribute is written onto it.

        Returns:
            tuple: ``(gt_instances, pred_instances)``, where
            ``pred_instances.assign_index`` is a list of ints holding, for
            each prediction, the index of its closest ground truth.
        """
        gt_beziers = gt_instances.beziers
        pred_beziers = pred_instances.beziers

        nearest_gt = []
        for pred_idx in range(len(pred_beziers)):
            # Summed absolute difference between this prediction and every
            # ground-truth entry, reduced over dim 1.
            l1_dist = torch.abs(gt_beziers - pred_beziers[pred_idx]).sum(dim=1)
            nearest_gt.append(int(torch.argmin(l1_dist)))

        pred_instances.assign_index = nearest_gt
        return gt_instances, pred_instances
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
from mmengine.structures import InstanceData
4+
5+
from mmocr.registry import TASK_UTILS
6+
7+
8+
@TASK_UTILS.register_module()
class ConcatSampler:
    """Sampler that appends the predictions to the ground truths."""

    def sample(self, gt_instances, pred_instances):
        """Merge ground-truth and predicted instances into one proposal set.

        Args:
            gt_instances: Ground-truth instances providing ``texts`` and
                ``beziers`` and supporting indexing.
            pred_instances: Predicted instances providing ``beziers`` and an
                ``assign_index`` attribute that selects, for every
                prediction, the ground truth whose text it borrows
                (presumably set by an assigner beforehand -- verify against
                the caller).

        Returns:
            ``gt_instances`` itself when there are no predictions; otherwise
            a new ``InstanceData`` carrying only ``texts`` and ``beziers``,
            with the ground-truth entries first and the prediction entries
            after them.
        """
        # Nothing predicted: hand the ground truths back unchanged.
        if len(pred_instances) == 0:
            return gt_instances

        merged = InstanceData()

        # Each prediction takes the text of the ground truth it was
        # assigned to.
        gt_texts = gt_instances.texts
        matched_texts = gt_instances[pred_instances.assign_index].texts
        merged.texts = gt_texts + matched_texts

        stacked = [gt_instances.beziers, pred_instances.beziers]
        merged.beziers = torch.cat(stacked, dim=0)
        return merged
20+
21+
22+
@TASK_UTILS.register_module()
class OnlyGTSampler:
    """Sampler that uses only the non-ignored ground truths."""

    def sample(self, gt_instances, pred_instances):
        """Return the ground-truth instances that are not flagged as ignored.

        Args:
            gt_instances: Ground-truth instances with an ``ignored``
                attribute that supports ``~`` (presumably a boolean tensor
                mask -- TODO confirm).
            pred_instances: Unused; accepted so the signature matches the
                other samplers.

        Returns:
            ``gt_instances`` indexed by the inverted ``ignored`` mask.
        """
        keep_mask = ~gt_instances.ignored
        return gt_instances[keep_mask]

projects/ABCNet/config/_base_/default_runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
default_scope = 'mmocr'
22
env_cfg = dict(
3-
cudnn_benchmark=True,
3+
cudnn_benchmark=False,
44
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
55
dist_cfg=dict(backend='nccl'),
66
)

0 commit comments

Comments
 (0)