Skip to content

Commit d63a439

Browse files
fix: ocr visualization and add ocr recognition metrics (#144)
* fix: ocr visualization Signed-off-by: samiullahchattha <[email protected]> * fix type error Signed-off-by: samiullahchattha <[email protected]> * fix: improve OCR visualizer * fix: build errors * add word and character accuracy metrics Signed-off-by: samiullahchattha <[email protected]> * strip leading or trailing whitespace in edit distance Signed-off-by: samiullahchattha <[email protected]> * fix visualizations Signed-off-by: samiuc <[email protected]> --------- Signed-off-by: samiullahchattha <[email protected]> Signed-off-by: samiuc <[email protected]> Co-authored-by: samiullahchattha <[email protected]>
1 parent 6c117a2 commit d63a439

File tree

6 files changed

+724
-239
lines changed

6 files changed

+724
-239
lines changed

docling_eval/cli/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def evaluate(
603603
json.dump(evaluation.model_dump(), fd, indent=2, sort_keys=True)
604604

605605
elif modality == EvaluationModality.OCR:
606-
ocr_evaluator = OCREvaluator()
606+
ocr_evaluator = OCREvaluator(intermediate_evaluations_path=odir)
607607
evaluation = ocr_evaluator( # type: ignore
608608
idir,
609609
split=split,

docling_eval/evaluators/ocr/benchmark_runner.py

Lines changed: 88 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from docling_core.types.doc.page import SegmentedPage
66

77
from docling_eval.evaluators.ocr.evaluation_models import (
8-
AggregatedBenchmarkMetrics,
98
OcrBenchmarkEntry,
109
OcrMetricsSummary,
1110
Word,
@@ -14,6 +13,7 @@
1413
from docling_eval.evaluators.ocr.processing_utils import (
1514
_CalculationConstants,
1615
_IgnoreZoneFilter,
16+
_IgnoreZoneFilterHWR,
1717
extract_word_from_text_cell,
1818
)
1919

@@ -26,6 +26,7 @@ def __init__(
2626
ignore_zone_filter_type: str = "default",
2727
add_space_for_merged_prediction_words: bool = True,
2828
add_space_for_merged_gt_words: bool = True,
29+
aggregation_mode: str = "union", # "mean" or "union"
2930
) -> None:
3031
self.model_identifier: str = model_identifier
3132
self.add_space_for_merged_prediction_words: bool = (
@@ -39,8 +40,13 @@ def __init__(
3940
] = {}
4041
self.image_to_ignore_zones_map: Dict[str, List[Word]] = {}
4142
self.calculator_type: str = performance_calculator_type
43+
self.aggregation_mode: str = aggregation_mode
4244

43-
self.ignore_zone_filter: _IgnoreZoneFilter = _IgnoreZoneFilter()
45+
self.ignore_zone_filter: "_IgnoreZoneFilter | _IgnoreZoneFilterHWR"
46+
if ignore_zone_filter_type.lower() == "hwr":
47+
self.ignore_zone_filter = _IgnoreZoneFilterHWR()
48+
else:
49+
self.ignore_zone_filter = _IgnoreZoneFilter()
4450

4551
def process_single_page_pair(
4652
self,
@@ -126,6 +132,70 @@ def calculate_aggregated_metrics(
126132
if key not in summed_metrics:
127133
summed_metrics[key] = ""
128134

135+
num_images = len(self.image_metrics_results)
136+
# Recognition aggregation
137+
if self.aggregation_mode == "union":
138+
total_weighted_tp_words: float = summed_metrics.get(
139+
"tp_words_weighted", 0.0
140+
)
141+
total_fp: float = summed_metrics.get(
142+
"number_of_false_positive_detections", 0.0
143+
)
144+
total_fn: float = summed_metrics.get(
145+
"number_of_false_negative_detections", 0.0
146+
)
147+
total_union_words: float = total_weighted_tp_words + total_fp + total_fn
148+
total_perfect_sensitive: float = summed_metrics.get(
149+
"perfect_matches_sensitive_weighted", 0.0
150+
)
151+
total_perfect_insensitive: float = summed_metrics.get(
152+
"perfect_matches_insensitive_weighted", 0.0
153+
)
154+
avg_word_acc_sensitive = total_perfect_sensitive / max(
155+
_CalculationConstants.EPS, total_union_words
156+
)
157+
avg_word_acc_insensitive = total_perfect_insensitive / max(
158+
_CalculationConstants.EPS, total_union_words
159+
)
160+
# Character (union)
161+
sum_ed_sensitive_tp: float = summed_metrics.get("sum_ed_sensitive_tp", 0.0)
162+
sum_ed_insensitive_tp: float = summed_metrics.get(
163+
"sum_ed_insensitive_tp", 0.0
164+
)
165+
sum_max_len_tp: float = summed_metrics.get("sum_max_len_tp", 0.0)
166+
sum_text_len_fp: float = summed_metrics.get("text_len_fp", 0.0)
167+
sum_text_len_fn: float = summed_metrics.get("text_len_fn", 0.0)
168+
total_chars_union: float = (
169+
sum_max_len_tp + sum_text_len_fp + sum_text_len_fn
170+
)
171+
avg_ed_union_sensitive: float = (
172+
sum_ed_sensitive_tp + sum_text_len_fp + sum_text_len_fn
173+
) / max(_CalculationConstants.EPS, total_chars_union)
174+
avg_ed_union_insensitive: float = (
175+
sum_ed_insensitive_tp + sum_text_len_fp + sum_text_len_fn
176+
) / max(_CalculationConstants.EPS, total_chars_union)
177+
avg_char_acc_sensitive = 1 - avg_ed_union_sensitive
178+
avg_char_acc_insensitive = 1 - avg_ed_union_insensitive
179+
# Convert to percentage later
180+
avg_word_acc_sensitive *= 100.0
181+
avg_word_acc_insensitive *= 100.0
182+
avg_char_acc_sensitive *= 100.0
183+
avg_char_acc_insensitive *= 100.0
184+
else:
185+
# Per-image mean of already-percentage metrics
186+
avg_word_acc_sensitive = (
187+
summed_metrics.get("word_accuracy_sensitive", 0.0) / num_images
188+
)
189+
avg_word_acc_insensitive = (
190+
summed_metrics.get("word_accuracy_insensitive", 0.0) / num_images
191+
)
192+
avg_char_acc_sensitive = (
193+
summed_metrics.get("character_accuracy_sensitive", 0.0) / num_images
194+
)
195+
avg_char_acc_insensitive = (
196+
summed_metrics.get("character_accuracy_insensitive", 0.0) / num_images
197+
)
198+
129199
total_true_positives: float = summed_metrics.get(
130200
"number_of_true_positive_matches", _CalculationConstants.EPS
131201
)
@@ -147,28 +217,35 @@ def calculate_aggregated_metrics(
147217
_CalculationConstants.EPS,
148218
)
149219

220+
avg_char_acc_sensitive = (
221+
summed_metrics.get("character_accuracy_sensitive", 0.0) / num_images
222+
)
223+
avg_char_acc_insensitive = (
224+
summed_metrics.get("character_accuracy_insensitive", 0.0) / num_images
225+
)
226+
150227
aggregated_metrics_data = {
151228
"f1": 100 * overall_f1_score,
152229
"recall": 100 * overall_recall,
153230
"precision": 100 * overall_precision,
231+
"word_accuracy_sensitive": avg_word_acc_sensitive,
232+
"word_accuracy_insensitive": avg_word_acc_insensitive,
233+
"character_accuracy_sensitive": avg_char_acc_sensitive,
234+
"character_accuracy_insensitive": avg_char_acc_insensitive,
154235
}
155236

156-
aggregated_metrics = AggregatedBenchmarkMetrics.model_validate(
157-
aggregated_metrics_data
158-
)
159-
output_results = aggregated_metrics.model_dump(by_alias=True)
160-
161-
for key, val in output_results.items():
237+
for key, val in aggregated_metrics_data.items():
162238
try:
163239
formatted_value: float = float(f"{{:.{float_precision}f}}".format(val))
164-
output_results[key] = formatted_value
240+
aggregated_metrics_data[key] = formatted_value
165241
except (ValueError, TypeError):
166242
pass
167-
return output_results
243+
244+
return aggregated_metrics_data
168245

169246
def get_formatted_metrics_summary(
170247
self,
171-
float_precision: int = 1,
248+
float_precision: int = 2,
172249
) -> List[Dict[str, Any]]:
173250
summary_list: List[Dict[str, Any]] = []
174251

docling_eval/evaluators/ocr/evaluation_models.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, List, Optional
1+
from typing import Any, Dict, List, Optional, Tuple
22

33
from docling_core.types.doc import BoundingBox
44
from docling_core.types.doc.page import TextCell
@@ -7,6 +7,17 @@
77

88
class _CalculationConstants:
99
EPS: float = 1.0e-6
10+
CHAR_NORMALIZE_MAP: Dict[str, str] = {
11+
"ﬁ": "fi",
12+
"ﬂ": "fl",
13+
"“": '"',
14+
"”": '"',
15+
"‘": "'",
16+
"’": "'",
17+
"—": "-",
18+
"–": "-",
19+
"\xa0": " ",
20+
}
1021

1122

1223
class Word(TextCell):
@@ -15,6 +26,8 @@ class Word(TextCell):
1526
matched: bool = Field(default=False)
1627
ignore_zone: Optional[bool] = None
1728
to_remove: Optional[bool] = None
29+
# number of GT words represented by this Word after merging
30+
word_weight: int = Field(default=1)
1831

1932
@property
2033
def bbox(self) -> BoundingBox:
@@ -42,6 +55,20 @@ class OcrMetricsSummary(BaseModel):
4255
detection_precision: float
4356
detection_recall: float
4457
detection_f1: float
58+
# recognition metrics
59+
word_accuracy_sensitive: float = 0.0
60+
word_accuracy_insensitive: float = 0.0
61+
character_accuracy_sensitive: float = 0.0
62+
character_accuracy_insensitive: float = 0.0
63+
# for dataset-level union aggregation
64+
tp_words_weighted: float = 0.0
65+
perfect_matches_sensitive_weighted: float = 0.0
66+
perfect_matches_insensitive_weighted: float = 0.0
67+
sum_ed_sensitive_tp: float = 0.0
68+
sum_ed_insensitive_tp: float = 0.0
69+
sum_max_len_tp: float = 0.0
70+
text_len_fp: float = 0.0
71+
text_len_fn: float = 0.0
4572

4673
class Config:
4774
populate_by_name = True
@@ -52,15 +79,6 @@ class OcrBenchmarkEntry(BaseModel):
5279
metrics: OcrMetricsSummary
5380

5481

55-
class AggregatedBenchmarkMetrics(BaseModel):
56-
f1: float = Field(alias="F1")
57-
recall: float = Field(alias="Recall")
58-
precision: float = Field(alias="Precision")
59-
60-
class Config:
61-
populate_by_name = True
62-
63-
6482
class DocumentEvaluationEntry(BaseModel):
6583
doc_id: str
6684

@@ -72,3 +90,31 @@ class OcrDatasetEvaluationResult(BaseModel):
7290
f1_score: float = 0.0
7391
recall: float = 0.0
7492
precision: float = 0.0
93+
word_accuracy_sensitive: float = 0.0
94+
word_accuracy_insensitive: float = 0.0
95+
character_accuracy_sensitive: float = 0.0
96+
character_accuracy_insensitive: float = 0.0
97+
98+
99+
class WordEvaluationMetadata(BaseModel):
100+
text: str
101+
confidence: Optional[float] = None
102+
bounding_box: BoundingBox
103+
is_true_positive: bool = False
104+
is_false_positive: bool = False
105+
is_false_negative: bool = False
106+
edit_distance_sensitive: Optional[int] = None
107+
edit_distance_insensitive: Optional[int] = None
108+
109+
110+
class TruePositiveMatch(BaseModel):
111+
pred: WordEvaluationMetadata
112+
gt: WordEvaluationMetadata
113+
114+
115+
class DocumentEvaluationMetadata(BaseModel):
116+
doc_id: str
117+
true_positives: List[TruePositiveMatch]
118+
false_positives: List[WordEvaluationMetadata]
119+
false_negatives: List[WordEvaluationMetadata]
120+
metrics: OcrMetricsSummary

docling_eval/evaluators/ocr/performance_calculator.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from collections import namedtuple
33
from typing import Dict, List, Tuple
44

5+
import edit_distance
6+
import numpy as np
57
from docling_core.types.doc.page import SegmentedPage
68

79
from docling_eval.evaluators.ocr.evaluation_models import (
@@ -17,6 +19,7 @@
1719
refine_prediction_to_many_gt_boxes,
1820
)
1921
from docling_eval.evaluators.ocr.processing_utils import (
22+
calculate_edit_distance,
2023
convert_word_to_text_cell,
2124
merge_words_into_one,
2225
)
@@ -389,6 +392,63 @@ def calculate_image_metrics(self) -> OcrMetricsSummary:
389392
recall + precision, _CalculationConstants.EPS
390393
)
391394

395+
sum_ed_sensitive = _CalculationConstants.EPS
396+
sum_ed_insensitive = _CalculationConstants.EPS
397+
sum_max_len_tp = _CalculationConstants.EPS
398+
perfect_matches_sensitive = 0
399+
perfect_matches_insensitive = 0
400+
total_tp_words_weighted = 0
401+
402+
for gt_word, pred_word in self.confirmed_gt_prediction_matches:
403+
gt_text = gt_word.text
404+
pred_text = pred_word.text
405+
# weight by the number of GT words represented by this merged word
406+
gt_weight = getattr(gt_word, "word_weight", 1)
407+
total_tp_words_weighted += gt_weight
408+
409+
max_len = max(len(gt_text), len(pred_text), 1)
410+
sum_max_len_tp += max_len
411+
412+
# Case-sensitive metrics
413+
ed_sensitive = calculate_edit_distance(gt_text, pred_text, None)
414+
sum_ed_sensitive += ed_sensitive
415+
if ed_sensitive == 0:
416+
perfect_matches_sensitive += gt_weight
417+
418+
# Case-insensitive metrics
419+
ed_insensitive = calculate_edit_distance(
420+
gt_text.upper(), pred_text.upper(), None
421+
)
422+
sum_ed_insensitive += ed_insensitive
423+
if ed_insensitive == 0:
424+
perfect_matches_insensitive += gt_weight
425+
426+
text_len_fp = sum(len(w.text) for w in self.current_false_positives)
427+
text_len_fn = sum(len(w.text) for w in self.current_false_negatives)
428+
429+
# word accuracy (union-based), weighted by GT merges for TPs
430+
total_union_words = (
431+
total_tp_words_weighted + num_false_positives + num_false_negatives
432+
)
433+
word_acc_union_sensitive = perfect_matches_sensitive / max(
434+
_CalculationConstants.EPS, total_union_words
435+
)
436+
word_acc_union_insensitive = perfect_matches_insensitive / max(
437+
_CalculationConstants.EPS, total_union_words
438+
)
439+
440+
# character accuracy (edit score union-based)
441+
total_chars_union = sum_max_len_tp + text_len_fp + text_len_fn
442+
avg_ed_union_sensitive = (sum_ed_sensitive + text_len_fp + text_len_fn) / max(
443+
_CalculationConstants.EPS, total_chars_union
444+
)
445+
avg_ed_union_insensitive = (
446+
sum_ed_insensitive + text_len_fp + text_len_fn
447+
) / max(_CalculationConstants.EPS, total_chars_union)
448+
449+
char_acc_sensitive = 1 - avg_ed_union_sensitive
450+
char_acc_insensitive = 1 - avg_ed_union_insensitive
451+
392452
metrics_summary_data = {
393453
"number_of_prediction_cells": num_prediction_cells_final,
394454
"number_of_gt_cells": num_gt_cells_final,
@@ -398,6 +458,19 @@ def calculate_image_metrics(self) -> OcrMetricsSummary:
398458
"detection_precision": 100.0 * precision,
399459
"detection_recall": 100.0 * recall,
400460
"detection_f1": 100.0 * f1_score,
461+
"word_accuracy_sensitive": 100.0 * word_acc_union_sensitive,
462+
"word_accuracy_insensitive": 100.0 * word_acc_union_insensitive,
463+
"character_accuracy_sensitive": 100.0 * char_acc_sensitive,
464+
"character_accuracy_insensitive": 100.0 * char_acc_insensitive,
465+
# additional counters for dataset-level union aggregation
466+
"tp_words_weighted": float(total_tp_words_weighted),
467+
"perfect_matches_sensitive_weighted": float(perfect_matches_sensitive),
468+
"perfect_matches_insensitive_weighted": float(perfect_matches_insensitive),
469+
"sum_ed_sensitive_tp": float(sum_ed_sensitive),
470+
"sum_ed_insensitive_tp": float(sum_ed_insensitive),
471+
"sum_max_len_tp": float(sum_max_len_tp),
472+
"text_len_fp": float(text_len_fp),
473+
"text_len_fn": float(text_len_fn),
401474
}
402475

403476
summary_instance = OcrMetricsSummary.model_validate(metrics_summary_data)

0 commit comments

Comments
 (0)