Skip to content

Commit cd10859

Browse files
Saidgurbuz and cau-git authored
fix: Sort key-value cells by ids (#55)
* add sort_cell_ids method
  Signed-off-by: Saidgurbuz <[email protected]>
* add logic to sort cell ids
  Signed-off-by: Saidgurbuz <[email protected]>
* fix styling
  Signed-off-by: Saidgurbuz <[email protected]>
---------
Signed-off-by: Saidgurbuz <[email protected]>
Co-authored-by: Christoph Auer <[email protected]>
1 parent 6efba9b commit cd10859

File tree

4 files changed

+34
-5
lines changed

4 files changed

+34
-5
lines changed

docling_eval/dataset_builders/doclaynet_v2_builder.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
extract_images,
3636
from_pil_to_base64uri,
3737
get_binhash,
38+
sort_cell_ids,
3839
)
3940

4041
# Get logger
@@ -365,23 +366,31 @@ def populate_key_value_item(
365366
doc: DoclingDocument to update
366367
kv_pairs: List of key-value pair dictionaries
367368
"""
368-
cells = []
369+
cell_by_id: Dict[int, GraphCell] = {}
369370
links = []
370371

371372
for pair in kv_pairs:
372373
key_data = pair["key"]
373374
value_data = pair["value"]
374375

375-
key_cell = self.create_graph_cell(key_data, GraphCellLabel.KEY)
376-
value_cell = self.create_graph_cell(value_data, GraphCellLabel.VALUE)
376+
if cell_by_id.get(key_data["cell_id"], None) is None:
377+
key_cell = self.create_graph_cell(key_data, GraphCellLabel.KEY)
378+
cell_by_id[key_data["cell_id"]] = key_cell
379+
else:
380+
key_cell = cell_by_id[key_data["cell_id"]]
377381

378-
cells.append(key_cell)
379-
cells.append(value_cell)
382+
if cell_by_id.get(value_data["cell_id"], None) is None:
383+
value_cell = self.create_graph_cell(value_data, GraphCellLabel.VALUE)
384+
cell_by_id[value_data["cell_id"]] = value_cell
385+
else:
386+
value_cell = cell_by_id[value_data["cell_id"]]
380387

381388
# link between key and value
382389
kv_link = self.create_graph_link(key_cell, value_cell)
383390
links.append(kv_link)
384391

392+
cells = list(cell_by_id.values())
393+
385394
overall_bbox = self.get_overall_bbox(
386395
links, cell_dict={cell.cell_id: cell for cell in cells}
387396
)
@@ -403,6 +412,9 @@ def populate_key_value_item(
403412
# Add the key_value_item to the document.
404413
doc.add_key_values(graph=graph, prov=prov)
405414

415+
# sort the cell ids in the graph
416+
sort_cell_ids(doc)
417+
406418
# The minimal fix for DocLayNetV2Builder is to add type annotation to link_pairs:
407419

408420
def create_kv_pairs(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:

docling_eval/dataset_builders/funsd_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
extract_images,
2222
from_pil_to_base64uri,
2323
get_binhash,
24+
sort_cell_ids,
2425
)
2526

2627
# Get logger
@@ -280,6 +281,8 @@ def populate_key_value_item(
280281

281282
doc.add_key_values(graph=graph, prov=prov)
282283

284+
sort_cell_ids(doc)
285+
283286
return doc
284287

285288
def iterate(self) -> Iterable[DatasetRecord]:

docling_eval/dataset_builders/xfund_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
extract_images,
2323
from_pil_to_base64uri,
2424
get_binhash,
25+
sort_cell_ids,
2526
)
2627

2728
# Get logger
@@ -286,6 +287,8 @@ def populate_key_value_item(
286287

287288
doc.add_key_values(graph=graph, prov=prov)
288289

290+
sort_cell_ids(doc)
291+
289292
return doc
290293

291294
def iterate(self) -> Iterable[DatasetRecord]:

docling_eval/utils/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,14 @@ def classify_cells(graph: GraphData) -> None:
538538
else:
539539
# fallback case.
540540
cell.label = GraphCellLabel.UNSPECIFIED
541+
542+
543+
def sort_cell_ids(doc: "DoclingDocument") -> None:
    """Renumber the cell ids of the document's first key-value graph in place.

    Every cell in ``doc.key_value_items[0].graph.cells`` gets its position in
    that list as its new ``cell_id`` (0, 1, 2, ...). Each link's
    ``source_cell_id`` and ``target_cell_id`` are then rewritten with the same
    old-id -> new-id table so the links keep pointing at the same cells.

    Despite the name, the ``cells`` list itself is not reordered; only the id
    values change. Only the first key-value item is processed.

    Args:
        doc: Document to update. A document without any key-value items is
            left untouched instead of raising ``IndexError``.

    Raises:
        KeyError: If a link references a cell id that no cell in the graph
            carries.
    """
    if not doc.key_value_items:
        # Nothing to renumber; avoid indexing into an empty list.
        return

    graph = doc.key_value_items[0].graph

    # Build the complete old-id -> new-id table before mutating any cell, so
    # that freshly assigned ids are never mistaken for old ones.
    new_id_by_old_id = {
        cell.cell_id: position for position, cell in enumerate(graph.cells)
    }

    for cell in graph.cells:
        cell.cell_id = new_id_by_old_id[cell.cell_id]

    for link in graph.links:
        link.source_cell_id = new_id_by_old_id[link.source_cell_id]
        link.target_cell_id = new_id_by_old_id[link.target_cell_id]

0 commit comments

Comments
 (0)