@@ -31,13 +31,12 @@ def __init__(self, model_path: Optional[str] = None):
3131
3232 self .session = OrtInferSession (model_path )
3333
34- def __call__ (self , img : np .ndarray ) -> Optional [np .ndarray ]:
34+ def __call__ (self , img : np .ndarray , ** kwargs ) -> Optional [np .ndarray ]:
3535 img_info = self .preprocess (img )
3636 pred = self .infer (img_info )
37- polygons = self .postprocess (img , pred )
37+ polygons = self .postprocess (img , pred , ** kwargs )
3838 if polygons .size == 0 :
3939 return None
40-
4140 polygons = polygons .reshape (polygons .shape [0 ], 4 , 2 )
4241 polygons [:, 3 , :], polygons [:, 1 , :] = (
4342 polygons [:, 1 , :].copy (),
@@ -68,7 +67,25 @@ def infer(self, input):
6867 result = result [0 ].astype (np .uint8 )
6968 return result
7069
71- def postprocess (self , img , pred , row = 50 , col = 30 , alph = 15 , angle = 50 ):
70+ def postprocess (self , img , pred , ** kwargs ):
71+ row = kwargs .get ("row" , 50 ) if kwargs else 50
72+ col = kwargs .get ("col" , 30 ) if kwargs else 30
73+ h_lines_threshold = kwargs .get ("h_lines_threshold" , 100 ) if kwargs else 100
74+ v_lines_threshold = kwargs .get ("v_lines_threshold" , 15 ) if kwargs else 15
75+ angle = kwargs .get ("angle" , 50 ) if kwargs else 50
76+ morph_close = (
77+ kwargs .get ("morph_close" , True ) if kwargs else True
78+ ) # 是否进行闭合运算以找到更多小的框
79+ more_h_lines = (
80+ kwargs .get ("more_h_lines" , True ) if kwargs else True
81+ ) # 是否调整以找到更多的横线
82+ more_v_lines = (
83+ kwargs .get ("more_v_lines" , True ) if kwargs else True
84+ ) # 是否调整以找到更多的横线
85+ extend_line = (
86+ kwargs .get ("extend_line" , True ) if kwargs else True
87+ ) # 是否进行线段延长使得端点连接
88+
7289 ori_shape = img .shape
7390 pred = np .uint8 (pred )
7491 hpred = copy .deepcopy (pred ) # 横线
@@ -89,16 +106,19 @@ def postprocess(self, img, pred, row=50, col=30, alph=15, angle=50):
89106 vpred = cv2 .morphologyEx (
90107 vpred , cv2 .MORPH_CLOSE , vkernel , iterations = 1
91108 ) # 先膨胀后腐蚀的过程
92- hpred = cv2 .morphologyEx (hpred , cv2 .MORPH_CLOSE , hkernel , iterations = 1 )
109+ if morph_close :
110+ hpred = cv2 .morphologyEx (hpred , cv2 .MORPH_CLOSE , hkernel , iterations = 1 )
93111 colboxes = get_table_line (vpred , axis = 1 , lineW = col ) # 竖线
94112 rowboxes = get_table_line (hpred , axis = 0 , lineW = row ) # 横线
95- # rboxes_row_, rboxes_col_ = adjust_lines(rowboxes, colboxes, alph = alph, angle=angle)
96- rboxes_row_ = adjust_lines (rowboxes , alph = 100 , angle = angle )
97- rboxes_col_ = adjust_lines (colboxes , alph = alph , angle = angle )
113+ rboxes_row_ , rboxes_col_ = [], []
114+ if more_h_lines :
115+ rboxes_row_ = adjust_lines (rowboxes , alph = h_lines_threshold , angle = angle )
116+ if more_v_lines :
117+ rboxes_col_ = adjust_lines (colboxes , alph = v_lines_threshold , angle = angle )
98118 rowboxes += rboxes_row_
99119 colboxes += rboxes_col_
100- rowboxes , colboxes = final_adjust_lines ( rowboxes , colboxes )
101-
120+ if extend_line :
121+ rowboxes , colboxes = final_adjust_lines ( rowboxes , colboxes )
102122 tmp = np .zeros (img .shape [:2 ], dtype = "uint8" )
103123 tmp = draw_lines (tmp , rowboxes + colboxes , color = 255 , lineW = 2 )
104124 labels = measure .label (tmp < 255 , connectivity = 2 ) # 8连通区域标记
0 commit comments