11# Copyright (c) OpenMMLab. All rights reserved.
2- from typing import Tuple
2+ from typing import Optional , Sequence , Tuple
33
44from mmengine .structures import LabelData
55from torch import Tensor
@@ -15,14 +15,19 @@ class RecRoIHead(BaseRoIHead):
1515 """Simplest rec roi head including one roi extractor and one rec head."""
1616
1717 def __init__ (self ,
18- neck = None ,
18+ inputs_indices : Optional [Sequence ] = None ,
19+ neck : OptMultiConfig = None ,
20+ assigner : OptMultiConfig = None ,
1921 sampler : OptMultiConfig = None ,
2022 roi_extractor : OptMultiConfig = None ,
2123 rec_head : OptMultiConfig = None ,
2224 init_cfg = None ):
2325 super ().__init__ (init_cfg )
24- if sampler is not None :
25- self .sampler = TASK_UTILS .build (sampler )
26+ self .inputs_indices = inputs_indices
27+ self .assigner = assigner
28+ if assigner is not None :
29+ self .assigner = TASK_UTILS .build (assigner )
30+ self .sampler = TASK_UTILS .build (sampler )
2631 if neck is not None :
2732 self .neck = MODELS .build (neck )
2833 self .roi_extractor = MODELS .build (roi_extractor )
@@ -43,11 +48,39 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
4348 Returns:
4449 dict[str, Tensor]: A dictionary of loss components
4550 """
46- proposals = [
47- ds .gt_instances [~ ds .gt_instances .ignored ] for ds in data_samples
48- ]
51+
52+ if self .inputs_indices is not None :
53+ inputs = [inputs [i ] for i in self .inputs_indices ]
54+ # proposals = [
55+ # ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
56+ # ]
57+ proposals = list ()
58+ for ds in data_samples :
59+ pred_instances = ds .pred_instances
60+ gt_instances = ds .gt_instances
61+ # # assign
62+ # gt_beziers = gt_instances.beziers
63+ # pred_beziers = pred_instances.beziers
64+ # assign_index = [
65+ # int(
66+ # torch.argmin(
67+ # torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
68+ # for i in range(len(pred_beziers))
69+ # ]
70+ # proposal = InstanceData()
71+ # proposal.texts = gt_instances.texts + gt_instances[
72+ # assign_index].texts
73+ # proposal.beziers = torch.cat(
74+ # [gt_instances.beziers, pred_instances.beziers], dim=0)
75+ if self .assigner :
76+ gt_instances , pred_instances = self .assigner .assign (
77+ gt_instances , pred_instances )
78+ proposal = self .sampler .sample (gt_instances , pred_instances )
79+ proposals .append (proposal )
4980
5081 proposals = [p for p in proposals if len (p ) > 0 ]
82+ if hasattr (self , 'neck' ) and self .neck is not None :
83+ inputs = self .neck (inputs )
5184 bbox_feats = self .roi_extractor (inputs , proposals )
5285 rec_data_samples = [
5386 TextRecogDataSample (gt_text = LabelData (item = text ))
@@ -57,6 +90,7 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
5790
5891 def predict (self , inputs : Tuple [Tensor ],
5992 data_samples : DetSampleList ) -> RecSampleList :
93+ inputs = inputs [:3 ]
6094 if hasattr (self , 'neck' ) and self .neck is not None :
6195 inputs = self .neck (inputs )
6296 pred_instances = [ds .pred_instances for ds in data_samples ]
0 commit comments