11# -*- encoding: utf-8 -*-
22# @Author: SWHL
334+ import re
45from typing import List , Tuple
56
67import numpy as np
@@ -299,7 +300,7 @@ def extract_boxes(self, predictions):
299300
300301
301302class DocLayoutPostProcess :
302- def __init__ (self , labels : List [str ], conf_thres = 0.7 , iou_thres = 0.5 ):
303+ def __init__ (self , labels : List [str ], conf_thres = 0.2 , iou_thres = 0.5 ):
303304 self .labels = labels
304305 self .conf_threshold = conf_thres
305306 self .iou_threshold = iou_thres
@@ -308,40 +309,67 @@ def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
308309
def __call__(
    self,
    preds,
    ori_img_shape: Tuple[int, int],
    img_shape: Tuple[int, int] = (1024, 1024),
):
    """Turn raw model output into boxes, scores and label names.

    Only the first image of the batch is processed.

    Args:
        preds: Model outputs; ``preds[0]`` is indexed as
            ``[image, detection, column]`` where columns 0-3 are the box,
            column 4 is the confidence and column 5 is the class index.
            # assumes a NumPy array (``.astype`` is called on it) -- TODO confirm
        ori_img_shape: Shape of the original image, forwarded to
            ``scale_boxes`` as the target shape.
        img_shape: Shape of the network input the boxes are expressed in,
            forwarded to ``scale_boxes`` as the source shape.

    Returns:
        A ``(boxes, confidences, labels)`` tuple: the rescaled boxes, the
        matching confidence column, and the entries of ``self.labels``
        selected by the class-index column.
    """
    # Work on the first image only; this matches the previous behaviour of
    # filtering every image and then keeping element [0].
    first_image = preds[0][0]

    # Boolean indexing yields a new array for NumPy inputs, so the in-place
    # rescaling below does not write into the caller's data.
    keep = first_image[..., 4] > self.conf_threshold
    detections = first_image[keep]

    detections[:, :4] = scale_boxes(img_shape, detections[:, :4], ori_img_shape)

    boxes = detections[:, :4]
    confidences = detections[:, 4]
    class_ids = detections[:, 5].astype(int)
    labels = [self.labels[i] for i in class_ids]
    return boxes, confidences, labels
338326
339327
340- def rescale_boxes (boxes , input_width , input_height , img_width , img_height ):
341- # Rescale boxes to original image dimensions
342- input_shape = np .array ([input_width , input_height , input_width , input_height ])
343- boxes = np .divide (boxes , input_shape , dtype = np .float32 )
344- boxes *= np .array ([img_width , img_height , img_width , img_height ])
def scale_boxes(
    img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False
):
    """Rescale boxes from one image shape (img1_shape) to another (img0_shape).

    Args:
        img1_shape: ``(height, width)`` of the image the boxes currently
            refer to.
        boxes: Array whose last axis starts with the four box coordinates,
            ``(x1, y1, x2, y2)`` by default. Floating-point arrays are
            modified in place; integer NumPy arrays are first copied to
            ``float32`` because the division below cannot be stored in an
            integer array.
        img0_shape: ``(height, width)`` of the target image.
        ratio_pad: Optional ``(ratio, pad)`` pair. When given,
            ``ratio_pad[0][0]`` is used as the gain and ``ratio_pad[1]`` as
            the ``(x, y)`` padding instead of deriving them from the shapes.
        padding: If True, subtract the padding from the coordinates before
            dividing by the gain; if False, only divide.
        xywh: If True, only the first two coordinates are shifted by the
            padding (the last two are treated as sizes, not positions).

    Returns:
        The rescaled boxes after being passed through ``clip_boxes`` with
        ``img0_shape``.
    """
    if isinstance(boxes, np.ndarray) and not np.issubdtype(
        boxes.dtype, np.floating
    ):
        # ``boxes[..., :4] /= gain`` raises a casting error on integer arrays.
        boxes = boxes.astype(np.float32)

    if ratio_pad is None:
        # Derive gain and padding from the two shapes; gain = old / new.
        gain = min(
            img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]
        )
        # Half of the leftover space on each axis, in (x, y) order.
        pad = (
            round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
            round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
        )
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        boxes[..., 0] -= pad[0]  # x padding
        boxes[..., 1] -= pad[1]  # y padding
        if not xywh:
            boxes[..., 2] -= pad[0]  # x padding
            boxes[..., 3] -= pad[1]  # y padding

    boxes[..., :4] /= gain
    return clip_boxes(boxes, img0_shape)
368+
369+
def clip_boxes(boxes, shape):
    """Clamp box coordinates to the bounds of an image, in place.

    Args:
        boxes: Array whose last axis holds ``(x1, y1, x2, y2)`` in columns
            0-3; it is modified in place.
        shape: Image shape indexed as ``shape[0]`` = height and
            ``shape[1]`` = width.

    Returns:
        The same ``boxes`` object, with x values limited to ``[0, width]``
        and y values limited to ``[0, height]``.
    """
    boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
    boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
    return boxes
346374
347375
0 commit comments