@@ -91,19 +91,13 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
9191 line_width = self .line_width ,
9292 alpha = self .alpha )
9393
94- # draw gt/pred labels
95- if gt_labels is not None and pred_labels is not None :
94+ areas = (bboxes [:, 3 ] - bboxes [:, 1 ]) * (bboxes [:, 2 ] - bboxes [:, 0 ])
95+ scales = _get_adaptive_scales (areas )
96+ positions = (bboxes [:, :2 ] + bboxes [:, 2 :]) // 2
97+
98+ if gt_labels is not None :
9699 gt_tokens_biolabel = gt_labels .item
97100 gt_words_label = []
98- pred_tokens_biolabel = pred_labels .item
99- pred_words_label = []
100-
101- if 'score' in pred_labels :
102- pred_tokens_biolabel_score = pred_labels .score
103- pred_words_label_score = []
104- else :
105- pred_tokens_biolabel_score = None
106- pred_words_label_score = None
107101
108102 pre_word_id = None
109103 for idx , cur_word_id in enumerate (word_ids ):
@@ -112,36 +106,60 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
112106 gt_words_label_name = gt_tokens_biolabel [idx ][2 :] \
113107 if gt_tokens_biolabel [idx ] != 'O' else 'other'
114108 gt_words_label .append (gt_words_label_name )
109+ pre_word_id = cur_word_id
110+ assert len (gt_words_label ) == len (bboxes )
111+ if pred_labels is not None :
112+ pred_tokens_biolabel = pred_labels .item
113+ pred_words_label = []
114+ pred_tokens_biolabel_score = pred_labels .score
115+ pred_words_label_score = []
116+
117+ pre_word_id = None
118+ for idx , cur_word_id in enumerate (word_ids ):
119+ if cur_word_id is not None :
120+ if cur_word_id != pre_word_id :
115121 pred_words_label_name = pred_tokens_biolabel [idx ][2 :] \
116122 if pred_tokens_biolabel [idx ] != 'O' else 'other'
117123 pred_words_label .append (pred_words_label_name )
118- if pred_tokens_biolabel_score is not None :
119- pred_words_label_score .append (
120- pred_tokens_biolabel_score [idx ])
124+ pred_words_label_score .append (
125+ pred_tokens_biolabel_score [idx ])
121126 pre_word_id = cur_word_id
122- assert len (gt_words_label ) == len (bboxes )
123127 assert len (pred_words_label ) == len (bboxes )
124128
125- areas = (bboxes [:, 3 ] - bboxes [:, 1 ]) * (
126- bboxes [:, 2 ] - bboxes [:, 0 ])
127- scales = _get_adaptive_scales (areas )
128- positions = (bboxes [:, :2 ] + bboxes [:, 2 :]) // 2
129-
129+ # draw gt or pred labels
130+ if gt_labels is not None and pred_labels is not None :
130131 for i , (pos , gt , pred ) in enumerate (
131132 zip (positions , gt_words_label , pred_words_label )):
132- if pred_words_label_score is not None :
133- score = round (float (pred_words_label_score [i ]) * 100 , 1 )
134- label_text = f'{ gt } | { pred } ({ score } )'
135- else :
136- label_text = f'{ gt } | { pred } '
137-
133+ score = round (float (pred_words_label_score [i ]) * 100 , 1 )
134+ label_text = f'{ gt } | { pred } ({ score } )'
138135 self .draw_texts (
139136 label_text ,
140137 pos ,
141138 colors = self .label_color if gt == pred else 'r' ,
142139 font_sizes = int (13 * scales [i ]),
143140 vertical_alignments = 'center' ,
144141 horizontal_alignments = 'center' )
142+ elif pred_labels is not None :
143+ for i , (pos , pred ) in enumerate (zip (positions , pred_words_label )):
144+ score = round (float (pred_words_label_score [i ]) * 100 , 1 )
145+ label_text = f'Pred: { pred } ({ score } )'
146+ self .draw_texts (
147+ label_text ,
148+ pos ,
149+ colors = self .label_color ,
150+ font_sizes = int (13 * scales [i ]),
151+ vertical_alignments = 'center' ,
152+ horizontal_alignments = 'center' )
153+ elif gt_labels is not None :
154+ for i , (pos , gt ) in enumerate (zip (positions , gt_words_label )):
155+ label_text = f'GT: { gt } '
156+ self .draw_texts (
157+ label_text ,
158+ pos ,
159+ colors = self .label_color ,
160+ font_sizes = int (13 * scales [i ]),
161+ vertical_alignments = 'center' ,
162+ horizontal_alignments = 'center' )
145163
146164 return self .get_image ()
147165
0 commit comments