Skip to content

Commit 5d7a4b8

Browse files
authored
Merge pull request #4418 from myhloli/dev
Dev
2 parents c732563 + 263eb3f commit 5d7a4b8

File tree

3 files changed

+66
-16
lines changed

3 files changed

+66
-16
lines changed

mineru/backend/hybrid/hybrid_magic_model.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -424,7 +424,8 @@ def get_objects():
424424
# 调用通用方法
425425
return tie_up_category_by_index(
426426
get_subjects,
427-
get_objects
427+
get_objects,
428+
object_block_type=object_block_type
428429
)
429430

430431

mineru/backend/vlm/vlm_magic_model.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -349,7 +349,8 @@ def get_objects():
349349
# 调用通用方法
350350
return tie_up_category_by_index(
351351
get_subjects,
352-
get_objects
352+
get_objects,
353+
object_block_type=object_block_type
353354
)
354355

355356

mineru/utils/magic_model_utils.py

Lines changed: 62 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
包含两个MagicModel类中重复使用的方法和逻辑
33
"""
44
from typing import List, Dict, Any, Callable
5+
6+
from loguru import logger
57
from mineru.utils.boxbase import bbox_distance, bbox_center_distance, is_in
68

79

@@ -172,11 +174,15 @@ def tie_up_category_by_index(
172174
get_subjects_func: Callable,
173175
get_objects_func: Callable,
174176
extract_subject_func: Callable = None,
175-
extract_object_func: Callable = None
177+
extract_object_func: Callable = None,
178+
object_block_type: str = "object",
176179
):
177180
"""
178181
基于index的类别关联方法,用于将主体对象与客体对象进行关联
179-
客体优先匹配给index最接近的主体,index差值相同时使用bbox中心点距离作为tiebreaker
182+
客体优先匹配给index最接近的主体,匹配优先级为:
183+
1. index差值(最高优先级)
184+
2. bbox边缘距离(相邻边距离)
185+
3. bbox中心点距离(最低优先级,作为最终tiebreaker)
180186
181187
参数:
182188
get_subjects_func: 函数,提取主体对象
@@ -207,6 +213,29 @@ def tie_up_category_by_index(
207213
"sub_idx": i,
208214
}
209215

216+
# 提取所有客体的index集合,用于计算有效index差值
217+
object_indices = set(obj["index"] for obj in objects)
218+
219+
def calc_effective_index_diff(obj_index: int, sub_index: int) -> int:
220+
"""
221+
Compute the effective index difference between an object index and a subject index.
222+
Effective difference = absolute difference - number of other objects' indices strictly inside the interval.
223+
In other words, the part of the gap between obj_index and sub_index that is caused by other objects is deducted.
224+
"""
225+
if obj_index == sub_index:
226+
return 0
227+
228+
start, end = min(obj_index, sub_index), max(obj_index, sub_index)
229+
abs_diff = end - start
230+
231+
# Count how many entries of object_indices fall in the open interval (start, end); both endpoints are excluded
232+
other_objects_count = 0
233+
for idx in range(start + 1, end):
234+
if idx in object_indices:
235+
other_objects_count += 1
236+
237+
return abs_diff - other_objects_count
238+
210239
# 为每个客体找到最匹配的主体
211240
for obj in objects:
212241
if len(subjects) == 0:
@@ -217,29 +246,48 @@ def tie_up_category_by_index(
217246
min_index_diff = float("inf")
218247
best_subject_indices = []
219248

220-
# 找出index差值最小的所有主体
249+
# 找出有效index差值最小的所有主体
221250
for i, subject in enumerate(subjects):
222251
sub_index = subject["index"]
223-
index_diff = abs(obj_index - sub_index)
252+
index_diff = calc_effective_index_diff(obj_index, sub_index)
224253

225254
if index_diff < min_index_diff:
226255
min_index_diff = index_diff
227256
best_subject_indices = [i]
228257
elif index_diff == min_index_diff:
229258
best_subject_indices.append(i)
230259

231-
# 如果有多个主体的index差值相同,使用中心点距离作为tiebreaker
232-
if len(best_subject_indices) > 1:
233-
min_center_dist = float("inf")
260+
if len(best_subject_indices) == 1:
234261
best_subject_idx = best_subject_indices[0]
235-
236-
for idx in best_subject_indices:
237-
center_dist = bbox_center_distance(obj["bbox"], subjects[idx]["bbox"])
238-
if center_dist < min_center_dist:
239-
min_center_dist = center_dist
240-
best_subject_idx = idx
262+
# 如果有多个主体的index差值相同(最多两个),根据边缘距离进行筛选
263+
elif len(best_subject_indices) == 2:
264+
# 计算所有候选主体的边缘距离
265+
edge_distances = [(idx, bbox_distance(obj["bbox"], subjects[idx]["bbox"])) for idx in best_subject_indices]
266+
edge_dist_diff = abs(edge_distances[0][1] - edge_distances[1][1])
267+
268+
for idx, edge_dist in edge_distances:
269+
logger.debug(f"Obj index: {obj_index}, Sub index: {subjects[idx]['index']}, Edge distance: {edge_dist}")
270+
271+
if edge_dist_diff > 2:
272+
# 边缘距离差值大于2,匹配边缘距离更小的主体
273+
best_subject_idx = min(edge_distances, key=lambda x: x[1])[0]
274+
logger.debug(f"Obj index: {obj_index}, edge_dist_diff > 2, matching to subject with min edge distance, index: {subjects[best_subject_idx]['index']}")
275+
elif object_block_type == "table_caption":
276+
# 边缘距离差值<=2且为table_caption,匹配index更大的主体
277+
best_subject_idx = max(best_subject_indices, key=lambda idx: subjects[idx]["index"])
278+
logger.debug(f"Obj index: {obj_index}, edge_dist_diff <= 2 and table_caption, matching to later subject with index: {subjects[best_subject_idx]['index']}")
279+
elif object_block_type.endswith("footnote"):
280+
# 边缘距离差值<=2且为footnote,匹配index更小的主体
281+
best_subject_idx = min(best_subject_indices, key=lambda idx: subjects[idx]["index"])
282+
logger.debug(f"Obj index: {obj_index}, edge_dist_diff <= 2 and footnote, matching to earlier subject with index: {subjects[best_subject_idx]['index']}")
283+
else:
284+
# 边缘距离差值<=2 且不适用特殊匹配规则,使用中心点距离匹配
285+
center_distances = [(idx, bbox_center_distance(obj["bbox"], subjects[idx]["bbox"])) for idx in best_subject_indices]
286+
for idx, center_dist in center_distances:
287+
logger.debug(f"Obj index: {obj_index}, Sub index: {subjects[idx]['index']}, Center distance: {center_dist}")
288+
best_subject_idx = min(center_distances, key=lambda x: x[1])[0]
241289
else:
242-
best_subject_idx = best_subject_indices[0]
290+
raise ValueError("More than two subjects have the same minimal index difference, which is unexpected.")
243291

244292
# 将客体添加到最佳主体的obj_bboxes中
245293
result_dict[best_subject_idx]["obj_bboxes"].append(extract_object_func(obj))

0 commit comments

Comments
 (0)