Skip to content

Commit efbdd0e

Browse files
feat: Parallelize the table evaluation
Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
1 parent 88051e4 commit efbdd0e

File tree

2 files changed

+172
-84
lines changed

2 files changed

+172
-84
lines changed

docling_eval/evaluators/base_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def supported_prediction_formats(self) -> List[PredictionFormats]:
118118
def save_intermediate_evaluations(
119119
self,
120120
evaluation_name: str,
121-
enunumerate_id: int,
121+
enumerate_id: int,
122122
doc_id: str,
123123
evaluations: List[UnitEvaluationType],
124124
) -> Optional[Path]:
@@ -131,7 +131,7 @@ def save_intermediate_evaluations(
131131
return None
132132

133133
evals = [ev.model_dump() for ev in evaluations]
134-
evaluation_filename = f"{evaluation_name}_{enunumerate_id:05d}_{doc_id}.json"
134+
evaluation_filename = f"{evaluation_name}_{enumerate_id:05d}_{doc_id}.json"
135135
evaluation_fn = self._intermediate_evaluations_path / evaluation_filename # type: ignore
136136
_log.info("Saving intermediate evaluations: %s", evaluation_fn)
137137
with open(evaluation_fn, "w") as fd:

docling_eval/evaluators/table_evaluator.py

Lines changed: 170 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import glob
22
import logging
33
import random
4+
from concurrent.futures import Executor, Future, ProcessPoolExecutor, as_completed
45
from pathlib import Path
56
from typing import Dict, List, Optional
67

@@ -90,7 +91,7 @@ def save_histogram_delta_row_col(self, figname: Path):
9091
plt.ylabel("%")
9192
plt.legend(loc="upper right")
9293

93-
logging.info(f"saving figure to {figname}")
94+
_log.info(f"saving figure to {figname}")
9495
plt.savefig(figname)
9596

9697

@@ -104,6 +105,62 @@ def is_complex_table(table: TableItem) -> bool:
104105
return False
105106

106107

108+
def evaluate_tables(
    teds_scorer,  # TEDScorer-like callable; kept untyped so workers can pickle it — TODO confirm type
    stopwords: list[str],
    doc_id: str,
    table_id: int,
    true_html: str,
    true_num_rows: int,
    true_num_cols: int,
    pred_html: str,
    pred_num_rows: int,
    pred_num_cols: int,
    is_complex: bool,
    structure_only: bool,
) -> TableEvaluation:
    r"""
    Compute the TEDS score between a ground-truth and a predicted table.

    Both tables are received as html-formatted strings. This is a module-level
    function (not a method) so it can be submitted to a ProcessPoolExecutor.

    Parameters
    ----------
    teds_scorer:
        Callable computing TEDS from two parsed html trees.
    stopwords:
        Tags/snippets stripped from BOTH html strings before scoring, to avoid
        penalizing markup present only on one side.
    doc_id, table_id:
        Identify the table within the dataset; copied into the result.
    true_html, true_num_rows, true_num_cols:
        Ground-truth table html and its dimensions.
    pred_html, pred_num_rows, pred_num_cols:
        Predicted table html and its dimensions.
    is_complex:
        Whether the ground-truth table has spanning cells.
    structure_only:
        If True, score only the table structure (ignore cell content).

    Return
    ------
    TableEvaluation
        Holds the rounded TEDS score together with the table metadata above.
    """
    # TODO: Check if exceptions can be thrown in the following code
    # Filter out tags that may be present in GT but not in prediction (and
    # vice versa) to avoid an artificial penalty.
    for stopword in stopwords:
        pred_html = pred_html.replace(stopword, "")
        true_html = true_html.replace(stopword, "")

    pred_html_obj = html.fromstring(pred_html)
    true_html_obj = html.fromstring(true_html)

    teds = teds_scorer(
        gt_table=true_html_obj,
        pred_table=pred_html_obj,
        structure_only=structure_only,
    )
    # Round for stable, readable reporting.
    teds = round(teds, 3)

    # Prepare output
    return TableEvaluation(
        TEDS=teds,
        is_complex=is_complex,
        filename=doc_id,
        table_id=table_id,
        true_ncols=true_num_cols,
        pred_ncols=pred_num_cols,
        true_nrows=true_num_rows,
        pred_nrows=pred_num_rows,
        structure_only_evaluation=structure_only,
    )
162+
163+
107164
class TableEvaluator(BaseEvaluator):
108165
r"""
109166
Evaluate table predictions from HF dataset with the columns:
@@ -114,6 +171,7 @@ def __init__(
114171
intermediate_evaluations_path: Optional[Path] = None,
115172
structure_only: bool = False,
116173
prediction_sources: List[PredictionFormats] = [],
174+
concurrency: int = 4,
117175
):
118176
supported_prediction_formats: List[PredictionFormats] = [
119177
PredictionFormats.DOCLING_DOCUMENT,
@@ -122,6 +180,7 @@ def __init__(
122180
if not prediction_sources:
123181
prediction_sources = supported_prediction_formats
124182
super().__init__(
183+
concurrency=concurrency,
125184
intermediate_evaluations_path=intermediate_evaluations_path,
126185
prediction_sources=prediction_sources,
127186
supported_prediction_formats=supported_prediction_formats,
@@ -142,7 +201,7 @@ def __call__(
142201
"GTDoclingDocument"
143202
"PredictionDoclingDocument"
144203
"""
145-
logging.info("Loading the split '%s' from: '%s'", split, ds_path)
204+
_log.info("Loading the split '%s' from: '%s'", split, ds_path)
146205

147206
ext_docdoc_loader: Optional[ExternalDoclingDocumentLoader] = None
148207
if external_predictions_path is not None:
@@ -151,9 +210,9 @@ def __call__(
151210
# Load the dataset
152211
split_path = str(ds_path / split / "*.parquet")
153212
split_files = glob.glob(split_path)
154-
logging.info("Files: %s", split_files)
213+
_log.info("Files: %s", split_files)
155214
ds = load_dataset("parquet", data_files={split: split_files})
156-
logging.info("Overview of dataset: %s", ds)
215+
_log.info("Overview of dataset: %s", ds)
157216

158217
# Select the split
159218
ds_selection: Dataset = ds[split]
@@ -163,54 +222,78 @@ def __call__(
163222
rejected_samples: Dict[EvaluationRejectionType, int] = {
164223
EvaluationRejectionType.MISSING_PREDICTION: 0,
165224
EvaluationRejectionType.EVALUATION_ERROR: 0,
225+
EvaluationRejectionType.MISMATHCED_DOCUMENT: 0,
166226
}
167227

168-
for i, data in tqdm(
169-
enumerate(ds_selection),
170-
desc="Table evaluations",
171-
ncols=120,
172-
total=len(ds_selection),
173-
):
174-
data_record = DatasetRecordWithPrediction.model_validate(data)
175-
doc_id = data_record.doc_id
176-
gt_doc = data_record.ground_truth_doc
177-
pred_doc = self._get_pred_doc(data_record, ext_docdoc_loader)
178-
if not pred_doc:
179-
_log.error("There is no prediction for doc_id=%s", doc_id)
180-
rejected_samples[EvaluationRejectionType.MISSING_PREDICTION] += 1
181-
continue
228+
with ProcessPoolExecutor(max_workers=self._concurrency) as executor:
229+
futures: list[Future] = []
230+
table_futures: list[Future]
231+
table_rejection: Optional[EvaluationRejectionType]
232+
233+
# Submit pages for execution
234+
_log.info("Submitting the tables for evaluation...")
235+
for i, data in enumerate(ds_selection):
236+
data_record = DatasetRecordWithPrediction.model_validate(data)
237+
doc_id = data_record.doc_id
238+
gt_doc = data_record.ground_truth_doc
239+
pred_doc = self._get_pred_doc(data_record, ext_docdoc_loader)
240+
if not pred_doc:
241+
_log.error("There is no prediction for doc_id=%s", doc_id)
242+
rejected_samples[EvaluationRejectionType.MISSING_PREDICTION] += 1
243+
continue
182244

183-
try:
184245
if not self._structure_only:
185-
results = self._evaluate_tables_in_documents(
246+
# Evaluate the tables with structure + content
247+
table_futures, table_rejection = self._evaluate_tables_in_documents(
248+
executor,
186249
doc_id=doc_id,
187250
true_doc=gt_doc,
188251
pred_doc=pred_doc,
189252
structure_only=False,
190253
)
191-
table_evaluations.extend(results)
192-
193-
if self._intermediate_evaluations_path:
194-
self.save_intermediate_evaluations(
195-
"TEDs_struct_content", i, doc_id, results
196-
)
197-
198-
results = self._evaluate_tables_in_documents(
199-
doc_id=data[BenchMarkColumns.DOC_ID],
254+
if table_rejection != None:
255+
rejected_samples[table_rejection] += 1
256+
continue
257+
futures.extend(table_futures)
258+
259+
# Always evaluate the tables with structure
260+
table_futures, table_rejection = self._evaluate_tables_in_documents(
261+
executor,
262+
doc_id=doc_id,
200263
true_doc=gt_doc,
201264
pred_doc=pred_doc,
202265
structure_only=True,
203266
)
204-
table_struct_evaluations.extend(results)
267+
if table_rejection != None:
268+
rejected_samples[table_rejection] += 1
269+
continue
270+
futures.extend(table_futures)
271+
272+
# Collect the futures
273+
_log.info("Collecting the tables for evaluations...")
274+
for future in tqdm(
275+
as_completed(futures),
276+
desc="Table evaluations",
277+
ncols=120,
278+
total=len(ds_selection),
279+
):
280+
table_evaluation: TableEvaluation = future.result()
281+
table_id: int = table_evaluation.table_id
282+
doc_id = table_evaluation.filename
283+
284+
if not table_evaluation.structure_only_evaluation:
285+
table_evaluations.append(table_evaluation)
286+
if self._intermediate_evaluations_path:
287+
self.save_intermediate_evaluations(
288+
"TEDs_struct_content", table_id, doc_id, [table_evaluation]
289+
)
290+
291+
table_struct_evaluations.append(table_evaluation)
205292
if self._intermediate_evaluations_path:
206293
self.save_intermediate_evaluations(
207-
"TEDs_struct", i, doc_id, results
294+
"TEDs_struct", table_id, doc_id, [table_evaluation]
208295
)
209296

210-
except Exception as ex:
211-
rejected_samples[EvaluationRejectionType.EVALUATION_ERROR] += 1
212-
_log.error("Error during tables evaluation for %s", doc_id)
213-
214297
_log.info(
215298
"Finish. %s documents were skipped due to evaluation errors",
216299
rejected_samples[EvaluationRejectionType.EVALUATION_ERROR],
@@ -247,28 +330,45 @@ def __call__(
247330

248331
def _evaluate_tables_in_documents(
249332
self,
333+
executor: Executor,
250334
doc_id: str,
251335
true_doc: DoclingDocument,
252336
pred_doc: DoclingDocument,
253337
structure_only: bool = False,
254-
) -> List[TableEvaluation]:
255-
r""" """
256-
table_evaluations = []
257-
true_tables = true_doc.tables
258-
pred_tables = pred_doc.tables
259-
_log.info(
260-
"#-true-tables: %s, #-pred-tables: %s", len(true_tables), len(pred_tables)
338+
) -> tuple[list[Future], Optional[EvaluationRejectionType]]:
339+
r"""
340+
1. Extract the tables from true/pred document
341+
2. Reject if the number of tables differs across true/pred
342+
3. Export table as html-formatted string.
343+
4. Submit the tables for evaluation
344+
5. Return futures (one per table)
345+
346+
Return
347+
------
348+
349+
"""
350+
futures: list[Future] = []
351+
true_tables: list[TableItem] = true_doc.tables
352+
pred_tables: list[TableItem] = pred_doc.tables
353+
true_tables_len = len(true_tables)
354+
pred_tables_len = len(pred_tables)
355+
_log.debug(
356+
"#-true-tables: %s, #-pred-tables: %s", true_tables_len, pred_tables_len
261357
)
262-
assert len(true_tables) == len(
263-
pred_tables
264-
), "len(true_tables)!=len(pred_tables)"
358+
# Reject the document if there is a mismatch in the number of tables between the true/pred doc
359+
if true_tables_len != pred_tables_len:
360+
_log.error(
361+
"Mismatched number of tables between GT and predictions: [%d, %d]. Skipping doc: %s",
362+
true_tables_len,
363+
pred_tables_len,
364+
doc_id,
365+
)
366+
return futures, EvaluationRejectionType.MISMATHCED_DOCUMENT
265367

266368
for table_id in range(len(true_tables)): # , len(pred_tables)):
267369
# Avoid items of type DocItemLabel.DOCUMENT_INDEX
268370
if true_tables[table_id].label != DocItemLabel.TABLE:
269-
logging.warning(
270-
f"Skipping table with label {true_tables[table_id].label}"
271-
)
371+
_log.warning(f"Skipping table with label {true_tables[table_id].label}")
272372
continue
273373

274374
try:
@@ -277,44 +377,32 @@ def _evaluate_tables_in_documents(
277377

278378
is_complex = is_complex_table(true_table)
279379

280-
true_html = true_table.export_to_html(true_doc)
281-
pred_html = pred_table.export_to_html(pred_doc)
282-
283-
# Filter out tags that may be present in GT but not in prediction to avoid penalty
284-
for stopword in self._stopwords:
285-
predicted_html = pred_html.replace(stopword, "")
286-
for stopword in self._stopwords:
287-
true_html = true_html.replace(stopword, "")
288-
289-
true_html_obj = html.fromstring(true_html)
290-
pred_html_obj = html.fromstring(pred_html)
291-
292-
teds = self._teds_scorer(
293-
gt_table=true_html_obj,
294-
pred_table=pred_html_obj,
295-
structure_only=structure_only,
296-
)
297-
# logging.info(f"teds: {teds}")
298-
299-
teds = round(teds, 3)
300-
table_evaluation = TableEvaluation(
301-
TEDS=teds,
302-
is_complex=is_complex,
303-
filename=doc_id,
304-
table_id=table_id,
305-
true_ncols=true_table.data.num_cols,
306-
pred_ncols=pred_table.data.num_cols,
307-
true_nrows=true_table.data.num_rows,
308-
pred_nrows=pred_table.data.num_rows,
309-
structure_only_evaluation=structure_only,
380+
true_html: str = true_table.export_to_html(true_doc)
381+
pred_html: str = pred_table.export_to_html(pred_doc)
382+
383+
# Submit table for evaluation
384+
futures.append(
385+
executor.submit(
386+
evaluate_tables,
387+
self._teds_scorer,
388+
self._stopwords,
389+
doc_id,
390+
table_id,
391+
true_html,
392+
true_table.data.num_rows,
393+
true_table.data.num_cols,
394+
pred_html,
395+
pred_table.data.num_rows,
396+
pred_table.data.num_cols,
397+
is_complex,
398+
structure_only,
399+
)
310400
)
311-
table_evaluations.append(table_evaluation)
312401
except Exception:
313-
logging.error(
402+
_log.error(
314403
f"Table {table_id} from document {doc_id} could not be compared!"
315404
)
316-
317-
return table_evaluations
405+
return futures, None
318406

319407
def _get_pred_doc(
320408
self,

0 commit comments

Comments
 (0)