22包含两个MagicModel类中重复使用的方法和逻辑
33"""
44from typing import List , Dict , Any , Callable
5+
6+ from loguru import logger
57from mineru .utils .boxbase import bbox_distance , bbox_center_distance , is_in
68
79
@@ -172,11 +174,15 @@ def tie_up_category_by_index(
172174 get_subjects_func : Callable ,
173175 get_objects_func : Callable ,
174176 extract_subject_func : Callable = None ,
175- extract_object_func : Callable = None
177+ extract_object_func : Callable = None ,
178+ object_block_type : str = "object" ,
176179):
177180 """
178181 基于index的类别关联方法,用于将主体对象与客体对象进行关联
179- 客体优先匹配给index最接近的主体,index差值相同时使用bbox中心点距离作为tiebreaker
182+ 客体优先匹配给index最接近的主体,匹配优先级为:
183+ 1. index差值(最高优先级)
184+ 2. bbox边缘距离(相邻边距离)
185+ 3. bbox中心点距离(最低优先级,作为最终tiebreaker)
180186
181187 参数:
182188 get_subjects_func: 函数,提取主体对象
@@ -207,6 +213,29 @@ def tie_up_category_by_index(
207213 "sub_idx" : i ,
208214 }
209215
216+ # 提取所有客体的index集合,用于计算有效index差值
217+ object_indices = set (obj ["index" ] for obj in objects )
218+
def calc_effective_index_diff(obj_index: int, sub_index: int) -> int:
    """
    Compute the effective index difference between an object and a subject.

    Effective difference = absolute index difference minus the number of
    other object indices lying strictly between the two indices. Any part
    of the gap that is caused only by intervening objects is therefore
    not counted.

    Reads ``object_indices`` (the collection of all object indices) from
    the enclosing scope.

    Args:
        obj_index: Index of the object being matched.
        sub_index: Index of the candidate subject.

    Returns:
        The adjusted difference; 0 when both indices are equal.
    """
    start, end = min(obj_index, sub_index), max(obj_index, sub_index)

    # Only indices strictly inside (start, end) are counted; both endpoints
    # are excluded. When the two indices are equal the range is empty, so
    # the result is 0 without a dedicated special case.
    other_objects_count = sum(
        1 for idx in range(start + 1, end) if idx in object_indices
    )

    return (end - start) - other_objects_count
238+
210239 # 为每个客体找到最匹配的主体
211240 for obj in objects :
212241 if len (subjects ) == 0 :
@@ -217,29 +246,48 @@ def tie_up_category_by_index(
217246 min_index_diff = float ("inf" )
218247 best_subject_indices = []
219248
220- # 找出index差值最小的所有主体
249+ # 找出有效index差值最小的所有主体
221250 for i , subject in enumerate (subjects ):
222251 sub_index = subject ["index" ]
223- index_diff = abs (obj_index - sub_index )
252+ index_diff = calc_effective_index_diff (obj_index , sub_index )
224253
225254 if index_diff < min_index_diff :
226255 min_index_diff = index_diff
227256 best_subject_indices = [i ]
228257 elif index_diff == min_index_diff :
229258 best_subject_indices .append (i )
230259
231- # 如果有多个主体的index差值相同,使用中心点距离作为tiebreaker
232- if len (best_subject_indices ) > 1 :
233- min_center_dist = float ("inf" )
260+ if len (best_subject_indices ) == 1 :
234261 best_subject_idx = best_subject_indices [0 ]
235-
236- for idx in best_subject_indices :
237- center_dist = bbox_center_distance (obj ["bbox" ], subjects [idx ]["bbox" ])
238- if center_dist < min_center_dist :
239- min_center_dist = center_dist
240- best_subject_idx = idx
262+ # 如果有多个主体的index差值相同(最多两个),根据边缘距离进行筛选
263+ elif len (best_subject_indices ) == 2 :
264+ # 计算所有候选主体的边缘距离
265+ edge_distances = [(idx , bbox_distance (obj ["bbox" ], subjects [idx ]["bbox" ])) for idx in best_subject_indices ]
266+ edge_dist_diff = abs (edge_distances [0 ][1 ] - edge_distances [1 ][1 ])
267+
268+ for idx , edge_dist in edge_distances :
269+ logger .debug (f"Obj index: { obj_index } , Sub index: { subjects [idx ]['index' ]} , Edge distance: { edge_dist } " )
270+
271+ if edge_dist_diff > 2 :
272+ # 边缘距离差值大于2,匹配边缘距离更小的主体
273+ best_subject_idx = min (edge_distances , key = lambda x : x [1 ])[0 ]
274+ logger .debug (f"Obj index: { obj_index } , edge_dist_diff > 2, matching to subject with min edge distance, index: { subjects [best_subject_idx ]['index' ]} " )
275+ elif object_block_type == "table_caption" :
276+ # 边缘距离差值<=2且为table_caption,匹配index更大的主体
277+ best_subject_idx = max (best_subject_indices , key = lambda idx : subjects [idx ]["index" ])
278+ logger .debug (f"Obj index: { obj_index } , edge_dist_diff <= 2 and table_caption, matching to later subject with index: { subjects [best_subject_idx ]['index' ]} " )
279+ elif object_block_type .endswith ("footnote" ):
280+ # 边缘距离差值<=2且为footnote,匹配index更小的主体
281+ best_subject_idx = min (best_subject_indices , key = lambda idx : subjects [idx ]["index" ])
282+ logger .debug (f"Obj index: { obj_index } , edge_dist_diff <= 2 and footnote, matching to earlier subject with index: { subjects [best_subject_idx ]['index' ]} " )
283+ else :
284+ # 边缘距离差值<=2 且不适用特殊匹配规则,使用中心点距离匹配
285+ center_distances = [(idx , bbox_center_distance (obj ["bbox" ], subjects [idx ]["bbox" ])) for idx in best_subject_indices ]
286+ for idx , center_dist in center_distances :
287+ logger .debug (f"Obj index: { obj_index } , Sub index: { subjects [idx ]['index' ]} , Center distance: { center_dist } " )
288+ best_subject_idx = min (center_distances , key = lambda x : x [1 ])[0 ]
241289 else :
242- best_subject_idx = best_subject_indices [ 0 ]
290+ raise ValueError ( "More than two subjects have the same minimal index difference, which is unexpected." )
243291
244292 # 将客体添加到最佳主体的obj_bboxes中
245293 result_dict [best_subject_idx ]["obj_bboxes" ].append (extract_object_func (obj ))
0 commit comments