Skip to content

Commit 74e7b3e

Browse files
authored
fix: Validation fixes for list item impurity check (#169)
* deal with false positives in GroupConsecutiveReadingOrderRule list impurity check

  Signed-off-by: Christoph Auer <[email protected]>

* fix fillable_field label

  Signed-off-by: Christoph Auer <[email protected]>

---------

Signed-off-by: Christoph Auer <[email protected]>
1 parent 8be2e83 commit 74e7b3e

File tree

2 files changed

+154
-63
lines changed

2 files changed

+154
-63
lines changed

docling_eval/cvat_tools/parser.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
logger = logging.getLogger("docling_eval.cvat_tools.")
1919

20+
MANUAL_LABEL_MAP: dict[str, DocItemLabel] = {
21+
"fillable_field": DocItemLabel.EMPTY_VALUE,
22+
}
23+
2024

2125
def cvat_box_to_bbox(xtl: float, ytl: float, xbr: float, ybr: float) -> BoundingBox:
2226
"""Convert CVAT box coordinates to BoundingBox (TOPLEFT origin)."""
@@ -129,11 +133,14 @@ def _parse_image_element(
129133

130134
# Parse into one of the known enums; skip if unknown
131135
label_obj: Optional[object] = None
132-
try:
133-
label_obj = DocItemLabel(label_str)
134-
except ValueError:
135-
# Handle common CVAT label variations (e.g., "document Index" -> "document_index")
136-
normalized_label = label_str.lower().replace(" ", "_")
136+
normalized_label = label_str.lower().replace(" ", "_")
137+
138+
manual_label = MANUAL_LABEL_MAP.get(normalized_label)
139+
if manual_label is None:
140+
manual_label = MANUAL_LABEL_MAP.get(label_str)
141+
if manual_label is not None:
142+
label_obj = manual_label
143+
else:
137144
try:
138145
label_obj = DocItemLabel(normalized_label)
139146
except ValueError:

docling_eval/cvat_tools/validator.py

Lines changed: 142 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,77 +1135,161 @@ def validate(self, doc: DocumentStructure) -> List[CVATValidationError]:
11351135

11361136
id_to_element = {el.id: el for el in doc.elements}
11371137

1138-
# Check each group path
1138+
list_group_elements: Dict[int, Set[int]] = {}
1139+
11391140
for path_id, group_element_ids in doc.path_mappings.group.items():
11401141
if len(group_element_ids) < 2:
11411142
continue
11421143

1143-
# Only check groups containing list_item elements
1144-
group_elements_raw = [id_to_element.get(eid) for eid in group_element_ids]
1145-
group_elements: List[CVATElement] = [
1146-
el for el in group_elements_raw if el is not None
1147-
]
1144+
resolved_ids = {eid for eid in group_element_ids if eid in id_to_element}
1145+
1146+
if not resolved_ids:
1147+
continue
1148+
1149+
if any(
1150+
id_to_element[eid].label == DocItemLabel.LIST_ITEM
1151+
for eid in resolved_ids
1152+
):
1153+
list_group_elements[path_id] = resolved_ids
1154+
1155+
if not list_group_elements:
1156+
return errors
11481157

1149-
has_list_items = any(
1150-
el.label == DocItemLabel.LIST_ITEM for el in group_elements
1158+
element_to_groups: Dict[int, Set[int]] = {}
1159+
for path_id, element_ids in list_group_elements.items():
1160+
for element_id in element_ids:
1161+
element_to_groups.setdefault(element_id, set()).add(path_id)
1162+
1163+
group_positions: Dict[int, Dict[int, List[int]]] = {}
1164+
ro_to_groups: Dict[int, Set[int]] = {}
1165+
1166+
for ro_path_id, ro_element_ids in doc.path_mappings.reading_order.items():
1167+
for index, element_id in enumerate(ro_element_ids):
1168+
for group_path_id in element_to_groups.get(element_id, ()):
1169+
positions = group_positions.setdefault(
1170+
group_path_id, {}
1171+
).setdefault(ro_path_id, [])
1172+
positions.append(index)
1173+
ro_to_groups.setdefault(ro_path_id, set()).add(group_path_id)
1174+
1175+
def _ranges_overlap(positions_a: List[int], positions_b: List[int]) -> bool:
1176+
if not positions_a or not positions_b:
1177+
return False
1178+
return min(positions_a) <= max(positions_b) and min(positions_b) <= max(
1179+
positions_a
11511180
)
11521181

1153-
if not has_list_items:
1182+
adjacency: Dict[int, Set[int]] = {
1183+
path_id: set() for path_id in list_group_elements.keys()
1184+
}
1185+
1186+
for ro_path_id, groups_in_ro in ro_to_groups.items():
1187+
group_list = list(groups_in_ro)
1188+
for idx_a in range(len(group_list)):
1189+
path_a = group_list[idx_a]
1190+
positions_a = group_positions.get(path_a, {}).get(ro_path_id, [])
1191+
1192+
for idx_b in range(idx_a + 1, len(group_list)):
1193+
path_b = group_list[idx_b]
1194+
positions_b = group_positions.get(path_b, {}).get(ro_path_id, [])
1195+
1196+
if _ranges_overlap(positions_a, positions_b):
1197+
adjacency[path_a].add(path_b)
1198+
adjacency[path_b].add(path_a)
1199+
1200+
unresolved_paths = set(list_group_elements.keys())
1201+
clusters: List[Tuple[Set[int], Set[int]]] = []
1202+
1203+
while unresolved_paths:
1204+
start_path = unresolved_paths.pop()
1205+
stack: List[int] = [start_path]
1206+
cluster_paths: Set[int] = set()
1207+
cluster_elements: Set[int] = set()
1208+
1209+
while stack:
1210+
current_path = stack.pop()
1211+
if current_path in cluster_paths:
1212+
continue
1213+
1214+
cluster_paths.add(current_path)
1215+
cluster_elements.update(list_group_elements[current_path])
1216+
1217+
for neighbor in adjacency[current_path]:
1218+
if neighbor in unresolved_paths:
1219+
stack.append(neighbor)
1220+
unresolved_paths.remove(neighbor)
1221+
1222+
if not cluster_paths:
1223+
continue
1224+
1225+
clusters.append((cluster_paths, cluster_elements))
1226+
1227+
for cluster_paths, cluster_elements in clusters:
1228+
if len(cluster_elements) < 2:
11541229
continue
11551230

1156-
# Get the reading order for this group
1157-
# We need to check all reading order paths that touch these elements
1231+
sorted_cluster_paths = sorted(cluster_paths)
1232+
cluster_content_layers: Set[ContentLayer] = {
1233+
id_to_element[element_id].content_layer
1234+
for element_id in cluster_elements
1235+
if element_id in id_to_element
1236+
}
1237+
11581238
for ro_path_id, ro_element_ids in doc.path_mappings.reading_order.items():
1159-
# Find positions of grouped elements in this reading order
1160-
group_positions = []
1161-
for i, elem_id in enumerate(ro_element_ids):
1162-
if elem_id in group_element_ids:
1163-
group_positions.append((i, elem_id))
1164-
1165-
if len(group_positions) < 2:
1166-
# Group elements not in this reading order, or only one element
1239+
cluster_positions: List[Tuple[int, int]] = [
1240+
(index, element_id)
1241+
for index, element_id in enumerate(ro_element_ids)
1242+
if element_id in cluster_elements
1243+
]
1244+
1245+
if len(cluster_positions) < 2:
11671246
continue
11681247

1169-
# Check elements between consecutive group elements
1170-
group_positions.sort(key=lambda x: x[0]) # Sort by position
1171-
1172-
for j in range(len(group_positions) - 1):
1173-
start_pos, start_id = group_positions[j]
1174-
end_pos, end_id = group_positions[j + 1]
1175-
1176-
# Check if there are elements between start and end
1177-
sandwiched_elements = []
1178-
for pos in range(start_pos + 1, end_pos):
1179-
between_id = ro_element_ids[pos]
1180-
# If this element is not in the group, it's sandwiched
1181-
if between_id not in group_element_ids:
1182-
sandwiched_elements.append(between_id)
1183-
1184-
if sandwiched_elements:
1185-
# Get element details for the error message
1186-
sandwiched_labels = [
1187-
(
1188-
id_to_element[eid].label.value
1189-
if id_to_element.get(eid)
1190-
else "unknown"
1191-
)
1192-
for eid in sandwiched_elements
1193-
]
1194-
1195-
errors.append(
1196-
CVATValidationError(
1197-
error_type="list_group_reading_order_impurity",
1198-
message=(
1199-
f"Group path {path_id} (list group): Found {len(sandwiched_elements)} "
1200-
f"non-grouped element(s) in reading order between grouped list items "
1201-
f"{start_id} and {end_id}. Sandwiched elements: {sandwiched_elements} "
1202-
f"(labels: {sandwiched_labels}). These elements may not be properly "
1203-
f"nested in the converted document structure."
1204-
),
1205-
severity=ValidationSeverity.WARNING,
1206-
path_id=path_id,
1207-
)
1248+
cluster_positions.sort(key=lambda item: item[0])
1249+
1250+
for idx in range(len(cluster_positions) - 1):
1251+
start_pos, start_id = cluster_positions[idx]
1252+
end_pos, end_id = cluster_positions[idx + 1]
1253+
1254+
sandwiched_elements = [
1255+
ro_element_ids[position]
1256+
for position in range(start_pos + 1, end_pos)
1257+
if ro_element_ids[position] not in cluster_elements
1258+
# Ignore elements that belong to a different content layer (e.g. furniture)
1259+
and (
1260+
ro_element_ids[position] not in id_to_element
1261+
or id_to_element[ro_element_ids[position]].content_layer
1262+
in cluster_content_layers
1263+
)
1264+
]
1265+
1266+
if not sandwiched_elements:
1267+
continue
1268+
1269+
sandwiched_labels = [
1270+
(
1271+
id_to_element[element_id].label.value
1272+
if element_id in id_to_element
1273+
else "unknown"
1274+
)
1275+
for element_id in sandwiched_elements
1276+
]
1277+
1278+
errors.append(
1279+
CVATValidationError(
1280+
error_type="list_group_reading_order_impurity",
1281+
message=(
1282+
f"Group cluster {sorted_cluster_paths} (list groups): Found "
1283+
f"{len(sandwiched_elements)} non-grouped element(s) in reading order "
1284+
f"path {ro_path_id} between grouped list elements {start_id} and {end_id}. "
1285+
f"Sandwiched elements: {sandwiched_elements} (labels: {sandwiched_labels}). "
1286+
"These elements may not be properly nested in the converted document structure."
1287+
),
1288+
severity=ValidationSeverity.WARNING,
1289+
path_id=sorted_cluster_paths[0],
1290+
path_ids=sorted_cluster_paths,
12081291
)
1292+
)
12091293

12101294
return errors
12111295

0 commit comments

Comments
 (0)