diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py index 11a877a7142..15e8fd4ec67 100644 --- a/ppocr/postprocess/table_postprocess.py +++ b/ppocr/postprocess/table_postprocess.py @@ -129,11 +129,8 @@ def decode_label(self, batch): def _bbox_decode(self, bbox, shape): h, w, ratio_h, ratio_w, pad_h, pad_w = shape - h, w = pad_h, pad_w bbox[0::2] *= w bbox[1::2] *= h - bbox[0::2] /= ratio_w - bbox[1::2] /= ratio_h return bbox @@ -189,3 +186,14 @@ def _bbox_decode(self, bbox, shape): x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2 bbox = np.array([x1, y1, x2, y2]) return bbox + + +class SLANetPlusLabelDecode(TableLabelDecode): + def _bbox_decode(self, bbox, shape): + h, w, ratio_h, ratio_w, pad_h, pad_w = shape + h, w = pad_h, pad_w + bbox[0::2] *= w + bbox[1::2] *= h + bbox[0::2] /= ratio_w + bbox[1::2] /= ratio_h + return bbox