Skip to content

Commit 2fe3c41

Browse files
committed
abcnetv2 train
1 parent bf41194 commit 2fe3c41

File tree

14 files changed

+203
-30
lines changed

14 files changed

+203
-30
lines changed

projects/ABCNet/abcnet/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .bifpn import BiFPN
1313
from .coordinate_head import CoordinateHead
1414
from .rec_roi_head import RecRoIHead
15+
from .task_utils import * # noqa: F401,F403
1516

1617
__all__ = [
1718
'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',

projects/ABCNet/abcnet/model/abcnet_det_head.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class ABCNetDetHead(BaseTextDetHead):
1414

1515
def __init__(self,
1616
in_channels,
17-
module_loss=dict(type='ABCNetLoss'),
17+
module_loss=dict(type='ABCNetDetModuleLoss'),
1818
postprocessor=dict(type='ABCNetDetPostprocessor'),
1919
num_classes=1,
2020
strides=(4, 8, 16, 32, 64),
@@ -181,8 +181,8 @@ def forward_single(self, x, scale, stride):
181181
# float to avoid overflow when enabling FP16
182182
if self.use_scale:
183183
bbox_pred = scale(bbox_pred).float()
184-
else:
185-
bbox_pred = bbox_pred.float()
184+
# else:
185+
# bbox_pred = bbox_pred.float()
186186
if self.norm_on_bbox:
187187
# bbox_pred needed for gradient computation has been modified
188188
# by F.relu(bbox_pred) when run with PyTorch 1.10. So replace

projects/ABCNet/abcnet/model/abcnet_det_module_loss.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Dict, List, Tuple
33

44
import torch
5+
import torch.nn.functional as F
56
from mmdet.models.task_modules.prior_generators import MlvlPointGenerator
67
from mmdet.models.utils import multi_apply
78
from mmdet.utils import reduce_mean
@@ -149,11 +150,17 @@ def forward(self, inputs: Tuple[Tensor],
149150
avg_factor=centerness_denorm)
150151
loss_centerness = self.loss_centerness(
151152
pos_centerness, pos_centerness_targets, avg_factor=num_pos)
152-
loss_bezier = self.loss_bezier(
153-
pos_bezier_preds,
154-
pos_bezier_targets,
155-
weight=pos_centerness_targets[:, None],
156-
avg_factor=centerness_denorm)
153+
# loss_bezier = self.loss_bezier(
154+
# pos_bezier_preds,
155+
# pos_bezier_targets,
156+
# weight=pos_centerness_targets[:, None],
157+
# avg_factor=centerness_denorm)
158+
159+
loss_bezier = F.smooth_l1_loss(
160+
pos_bezier_preds, pos_bezier_targets, reduction='none')
161+
loss_bezier = (
162+
(loss_bezier.mean(dim=-1) * pos_centerness_targets).sum() /
163+
centerness_denorm)
157164
else:
158165
loss_bbox = pos_bbox_preds.sum()
159166
loss_centerness = pos_centerness.sum()
@@ -250,6 +257,7 @@ def _get_targets_single(self, data_sample: TextDetDataSample,
250257
polygons = gt_instances.polygons
251258
beziers = gt_bboxes.new([poly2bezier(poly) for poly in polygons])
252259
gt_instances.beziers = beziers
260+
# beziers = gt_instances.beziers
253261
if num_gts == 0:
254262
return gt_labels.new_full((num_points,), self.num_classes), \
255263
gt_bboxes.new_zeros((num_points, 4)), \

projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def __call__(self, pred_results, data_samples, training: bool = False):
216216
Returns:
217217
list[TextDetDataSample]: Batch of post-processed datasamples.
218218
"""
219-
if training:
220-
return data_samples
219+
# if training:
220+
# return data_samples
221221
cfg = self.train_cfg if training else self.test_cfg
222222
if cfg is None:
223223
cfg = {}

projects/ABCNet/abcnet/model/bezier_roi_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def forward(self, feats: Tuple[Tensor],
8888
# convert fp32 to fp16 when amp is on
8989
rois = rois.type_as(feats[0])
9090
out_size = self.roi_layers[0].output_size
91-
feats = feats[:3]
91+
# feats = feats[:3]
9292
num_levels = len(feats)
9393
roi_feats = feats[0].new_zeros(
9494
rois.size(0), self.out_channels, *out_size)

projects/ABCNet/abcnet/model/bifpn.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,12 @@ def __init__(self,
170170
self.bifpn_convs = nn.ModuleList()
171171
# weighted
172172
self.weight_two_nodes = nn.Parameter(
173-
torch.Tensor(2, levels).fill_(init))
173+
torch.Tensor(2, levels).fill_(init), requires_grad=True)
174+
174175
self.weight_three_nodes = nn.Parameter(
175-
torch.Tensor(3, levels - 2).fill_(init))
176-
self.relu = nn.ReLU()
176+
torch.Tensor(3, levels - 2).fill_(init), requires_grad=True)
177+
178+
# self.relu = nn.ReLU(inplace=False)
177179
for _ in range(2):
178180
for _ in range(self.levels - 1): # 1,2,3
179181
fpn_conv = nn.Sequential(
@@ -193,9 +195,10 @@ def forward(self, inputs):
193195
# build top-down and down-top path with stack
194196
levels = self.levels
195197
# w relu
196-
w1 = self.relu(self.weight_two_nodes)
197-
w1 /= torch.sum(w1, dim=0) + self.eps # normalize
198-
w2 = self.relu(self.weight_three_nodes)
198+
199+
_w1 = F.relu(self.weight_two_nodes)
200+
w1 = _w1 / (torch.sum(_w1, dim=0) + self.eps) # normalize
201+
w2 = F.relu(self.weight_three_nodes)
199202
# w2 /= torch.sum(w2, dim=0) + self.eps # normalize
200203
# build top-down
201204
idx_bifpn = 0

projects/ABCNet/abcnet/model/rec_roi_head.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from typing import Tuple
2+
from typing import Optional, Sequence, Tuple
33

44
from mmengine.structures import LabelData
55
from torch import Tensor
@@ -15,14 +15,19 @@ class RecRoIHead(BaseRoIHead):
1515
"""Simplest base roi head including one bbox head and one mask head."""
1616

1717
def __init__(self,
18-
neck=None,
18+
inputs_indices: Optional[Sequence] = None,
19+
neck: OptMultiConfig = None,
20+
assigner: OptMultiConfig = None,
1921
sampler: OptMultiConfig = None,
2022
roi_extractor: OptMultiConfig = None,
2123
rec_head: OptMultiConfig = None,
2224
init_cfg=None):
2325
super().__init__(init_cfg)
24-
if sampler is not None:
25-
self.sampler = TASK_UTILS.build(sampler)
26+
self.inputs_indices = inputs_indices
27+
self.assigner = assigner
28+
if assigner is not None:
29+
self.assigner = TASK_UTILS.build(assigner)
30+
self.sampler = TASK_UTILS.build(sampler)
2631
if neck is not None:
2732
self.neck = MODELS.build(neck)
2833
self.roi_extractor = MODELS.build(roi_extractor)
@@ -43,11 +48,39 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
4348
Returns:
4449
dict[str, Tensor]: A dictionary of loss components
4550
"""
46-
proposals = [
47-
ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
48-
]
51+
52+
if self.inputs_indices is not None:
53+
inputs = [inputs[i] for i in self.inputs_indices]
54+
# proposals = [
55+
# ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
56+
# ]
57+
proposals = list()
58+
for ds in data_samples:
59+
pred_instances = ds.pred_instances
60+
gt_instances = ds.gt_instances
61+
# # assign
62+
# gt_beziers = gt_instances.beziers
63+
# pred_beziers = pred_instances.beziers
64+
# assign_index = [
65+
# int(
66+
# torch.argmin(
67+
# torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
68+
# for i in range(len(pred_beziers))
69+
# ]
70+
# proposal = InstanceData()
71+
# proposal.texts = gt_instances.texts + gt_instances[
72+
# assign_index].texts
73+
# proposal.beziers = torch.cat(
74+
# [gt_instances.beziers, pred_instances.beziers], dim=0)
75+
if self.assigner:
76+
gt_instances, pred_instances = self.assigner.assign(
77+
gt_instances, pred_instances)
78+
proposal = self.sampler.sample(gt_instances, pred_instances)
79+
proposals.append(proposal)
4980

5081
proposals = [p for p in proposals if len(p) > 0]
82+
if hasattr(self, 'neck') and self.neck is not None:
83+
inputs = self.neck(inputs)
5184
bbox_feats = self.roi_extractor(inputs, proposals)
5285
rec_data_samples = [
5386
TextRecogDataSample(gt_text=LabelData(item=text))
@@ -57,6 +90,7 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
5790

5891
def predict(self, inputs: Tuple[Tensor],
5992
data_samples: DetSampleList) -> RecSampleList:
93+
inputs = inputs[:3]
6094
if hasattr(self, 'neck') and self.neck is not None:
6195
inputs = self.neck(inputs)
6296
pred_instances = [ds.pred_instances for ds in data_samples]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .assigner import L1DistanceAssigner
3+
from .sampler import ConcatSampler, OnlyGTSampler
4+
5+
__all__ = ['L1DistanceAssigner', 'ConcatSampler', 'OnlyGTSampler']
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
4+
from mmocr.registry import TASK_UTILS
5+
6+
7+
@TASK_UTILS.register_module()
8+
class L1DistanceAssigner:
9+
10+
def assign(self, gt_instances, pred_instances):
11+
gt_beziers = gt_instances.beziers
12+
pred_beziers = pred_instances.beziers
13+
assign_index = [
14+
int(
15+
torch.argmin(
16+
torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
17+
for i in range(len(pred_beziers))
18+
]
19+
pred_instances.assign_index = assign_index
20+
return gt_instances, pred_instances
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
from mmengine.structures import InstanceData
4+
5+
from mmocr.registry import TASK_UTILS
6+
7+
8+
@TASK_UTILS.register_module()
9+
class ConcatSampler:
10+
11+
def sample(self, gt_instances, pred_instances):
12+
if len(pred_instances) == 0:
13+
return gt_instances
14+
proposals = InstanceData()
15+
proposals.texts = gt_instances.texts + gt_instances[
16+
pred_instances.assign_index].texts
17+
proposals.beziers = torch.cat(
18+
[gt_instances.beziers, pred_instances.beziers], dim=0)
19+
return proposals
20+
21+
22+
@TASK_UTILS.register_module()
23+
class OnlyGTSampler:
24+
25+
def sample(self, gt_instances, pred_instances):
26+
return gt_instances[~gt_instances.ignored]

0 commit comments

Comments
 (0)