
Commit 3791783

Add table confidence model
Signed-off-by: Alina Buzachis <[email protected]>
1 parent e5cd702 commit 3791783

File tree

4 files changed: +1018 −96 lines changed


docling/datamodel/base_models.py

Lines changed: 136 additions & 96 deletions
@@ -200,12 +200,59 @@ class ContainerElement(
     pass
 
 
+# Create a type alias for score values
+ScoreValue = float
+
+
+class TableConfidenceScores(BaseModel):
+    """Holds the individual confidence scores for a single table."""
+    structure_score: ScoreValue = np.nan
+    cell_text_score: ScoreValue = np.nan
+    completeness_score: ScoreValue = np.nan
+    layout_score: ScoreValue = np.nan
+
+    @computed_field
+    @property
+    def total_table_score(self) -> ScoreValue:
+        """
+        Calculates the weighted average of the individual confidence scores
+        to produce a single total score.
+
+        The weights are:
+        - **Structure Score**: `0.3`
+        - **Cell Text Score**: `0.3`
+        - **Completeness Score**: `0.2`
+        - **Layout Score**: `0.2`
+
+        These weights are designed to give a balanced, all-purpose score.
+        Data integrity metrics (**structure** and **text**) are weighted more
+        heavily, as they are often the most critical for data extraction.
+
+        Returns:
+            ScoreValue: Weighted average score for the table.
+        """
+        scores = [self.structure_score, self.cell_text_score, self.completeness_score, self.layout_score]
+        weights = [0.3, 0.3, 0.2, 0.2]
+
+        # Skip unset (NaN) scores and renormalize the remaining weights.
+        valid_scores_and_weights = [(s, w) for s, w in zip(scores, weights) if not math.isnan(s)]
+        if not valid_scores_and_weights:
+            return np.nan
+
+        valid_scores = [s for s, w in valid_scores_and_weights]
+        valid_weights = [w for s, w in valid_scores_and_weights]
+        normalized_weights = [w / sum(valid_weights) for w in valid_weights]
+
+        return ScoreValue(np.average(valid_scores, weights=normalized_weights))
+
+
 class Table(BasePageElement):
     otsl_seq: List[str]
     num_rows: int = 0
     num_cols: int = 0
     table_cells: List[TableCell]
-
+    detailed_scores: Optional[TableConfidenceScores] = None
+
 
 class TableStructurePrediction(BaseModel):
     table_map: Dict[int, Table] = {}
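
The renormalization step above is worth spelling out: when some component scores were never computed (left as NaN), their weights are dropped and the remaining weights are rescaled to sum to 1 before averaging. A minimal standalone sketch of the same logic, with invented score values:

    import math
    import numpy as np

    # structure, cell_text, completeness, layout -- cell_text never computed
    scores = [0.9, math.nan, 0.8, 0.7]
    weights = [0.3, 0.3, 0.2, 0.2]

    valid = [(s, w) for s, w in zip(scores, weights) if not math.isnan(s)]
    valid_scores = [s for s, _ in valid]
    valid_weights = [w for _, w in valid]
    # Remaining weights 0.3, 0.2, 0.2 rescale to 3/7, 2/7, 2/7
    normalized = [w / sum(valid_weights) for w in valid_weights]

    print(np.average(valid_scores, weights=normalized))  # ~0.814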
@@ -242,12 +289,99 @@ class EquationPrediction(BaseModel):
     equation_map: Dict[int, TextElement] = {}
 
 
+class QualityGrade(str, Enum):
+    POOR = "poor"
+    FAIR = "fair"
+    GOOD = "good"
+    EXCELLENT = "excellent"
+    UNSPECIFIED = "unspecified"
+
+
+class PageConfidenceScores(BaseModel):
+    parse_score: ScoreValue = np.nan
+    layout_score: ScoreValue = np.nan
+    ocr_score: ScoreValue = np.nan
+    tables: Dict[int, TableConfidenceScores] = Field(default_factory=dict)
+
+    @computed_field  # type: ignore
+    @property
+    def table_score(self) -> ScoreValue:
+        if not self.tables:
+            return np.nan
+        return ScoreValue(np.nanmean([t.total_table_score for t in self.tables.values()]))
+
+    def _score_to_grade(self, score: ScoreValue) -> QualityGrade:
+        if score < 0.5:
+            return QualityGrade.POOR
+        elif score < 0.8:
+            return QualityGrade.FAIR
+        elif score < 0.9:
+            return QualityGrade.GOOD
+        elif score >= 0.9:
+            return QualityGrade.EXCELLENT
+
+        # NaN fails every comparison above, so unset scores grade as UNSPECIFIED.
+        return QualityGrade.UNSPECIFIED
+
+    @computed_field  # type: ignore
+    @property
+    def mean_grade(self) -> QualityGrade:
+        return self._score_to_grade(self.mean_score)
+
+    @computed_field  # type: ignore
+    @property
+    def low_grade(self) -> QualityGrade:
+        return self._score_to_grade(self.low_score)
+
+    @computed_field  # type: ignore
+    @property
+    def mean_score(self) -> ScoreValue:
+        return ScoreValue(
+            np.nanmean(
+                [
+                    self.ocr_score,
+                    self.table_score,
+                    self.layout_score,
+                    self.parse_score,
+                ]
+            )
+        )
+
+    @computed_field  # type: ignore
+    @property
+    def low_score(self) -> ScoreValue:
+        return ScoreValue(
+            np.nanquantile(
+                [
+                    self.ocr_score,
+                    self.table_score,
+                    self.layout_score,
+                    self.parse_score,
+                ],
+                q=0.05,
+            )
+        )
+
+
+class ConfidenceReport(BaseModel):
+    pages: Dict[int, PageConfidenceScores] = Field(
+        default_factory=lambda: defaultdict(PageConfidenceScores)
+    )
+    # The document-level scores are no longer computed properties; they are
+    # fields that the pipeline sets from the aggregated page scores.
+    mean_score: ScoreValue = np.nan
+    low_score: ScoreValue = np.nan
+    ocr_score: ScoreValue = np.nan
+    table_score: ScoreValue = np.nan
+    layout_score: ScoreValue = np.nan
+    parse_score: ScoreValue = np.nan
+
+
 class PagePredictions(BaseModel):
     layout: Optional[LayoutPrediction] = None
     tablestructure: Optional[TableStructurePrediction] = None
     figures_classification: Optional[FigureClassificationPrediction] = None
     equations_prediction: Optional[EquationPrediction] = None
     vlm_response: Optional[VlmPrediction] = None
+    confidence_scores: PageConfidenceScores = Field(default_factory=PageConfidenceScores)
 
 
 PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
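
A hedged usage sketch of the page-level aggregation (score values invented, not taken from this commit's tests): `np.nanmean` skips unset modalities, and `_score_to_grade` maps the result onto `QualityGrade`.

    import numpy as np

    page = PageConfidenceScores(parse_score=0.96, layout_score=0.92)
    # ocr_score and table_score stay NaN, so nanmean averages the two set values
    print(page.mean_score)  # ~0.94
    print(page.mean_grade)  # QualityGrade.EXCELLENT (>= 0.9)
    print(page.low_score)   # 5th percentile of the set scores -> ~0.922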
@@ -273,7 +407,7 @@ class Page(BaseModel):
     # page_hash: Optional[str] = None
     size: Optional[Size] = None
     parsed_page: Optional[SegmentedPdfPage] = None
-    predictions: PagePredictions = PagePredictions()
+    predictions: PagePredictions = Field(default_factory=PagePredictions)
     assembled: Optional[AssembledUnit] = None
 
     _backend: Optional["PdfPageBackend"] = (
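
The `predictions` change swaps a shared default instance for a per-instance factory, so each `Page` explicitly gets its own `PagePredictions` (and its own mutable `confidence_scores`). A small sketch of the behavior this guarantees, assuming `page_no` is the only required `Page` field here:

    import math

    p1, p2 = Page(page_no=0), Page(page_no=1)
    assert p1.predictions is not p2.predictions  # fresh object per Page

    p1.predictions.confidence_scores.ocr_score = 0.5
    assert math.isnan(p2.predictions.confidence_scores.ocr_score)  # no leakage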
@@ -357,97 +491,3 @@ class OpenAiApiResponse(BaseModel):
     choices: List[OpenAiResponseChoice]
     created: int
     usage: OpenAiResponseUsage
-
-
-# Create a type alias for score values
-ScoreValue = float
-
-
-class QualityGrade(str, Enum):
-    POOR = "poor"
-    FAIR = "fair"
-    GOOD = "good"
-    EXCELLENT = "excellent"
-    UNSPECIFIED = "unspecified"
-
-
-class PageConfidenceScores(BaseModel):
-    parse_score: ScoreValue = np.nan
-    layout_score: ScoreValue = np.nan
-    table_score: ScoreValue = np.nan
-    ocr_score: ScoreValue = np.nan
-
-    def _score_to_grade(self, score: ScoreValue) -> QualityGrade:
-        if score < 0.5:
-            return QualityGrade.POOR
-        elif score < 0.8:
-            return QualityGrade.FAIR
-        elif score < 0.9:
-            return QualityGrade.GOOD
-        elif score >= 0.9:
-            return QualityGrade.EXCELLENT
-
-        return QualityGrade.UNSPECIFIED
-
-    @computed_field  # type: ignore
-    @property
-    def mean_grade(self) -> QualityGrade:
-        return self._score_to_grade(self.mean_score)
-
-    @computed_field  # type: ignore
-    @property
-    def low_grade(self) -> QualityGrade:
-        return self._score_to_grade(self.low_score)
-
-    @computed_field  # type: ignore
-    @property
-    def mean_score(self) -> ScoreValue:
-        return ScoreValue(
-            np.nanmean(
-                [
-                    self.ocr_score,
-                    self.table_score,
-                    self.layout_score,
-                    self.parse_score,
-                ]
-            )
-        )
-
-    @computed_field  # type: ignore
-    @property
-    def low_score(self) -> ScoreValue:
-        return ScoreValue(
-            np.nanquantile(
-                [
-                    self.ocr_score,
-                    self.table_score,
-                    self.layout_score,
-                    self.parse_score,
-                ],
-                q=0.05,
-            )
-        )
-
-
-class ConfidenceReport(PageConfidenceScores):
-    pages: Dict[int, PageConfidenceScores] = Field(
-        default_factory=lambda: defaultdict(PageConfidenceScores)
-    )
-
-    @computed_field  # type: ignore
-    @property
-    def mean_score(self) -> ScoreValue:
-        return ScoreValue(
-            np.nanmean(
-                [c.mean_score for c in self.pages.values()],
-            )
-        )
-
-    @computed_field  # type: ignore
-    @property
-    def low_score(self) -> ScoreValue:
-        return ScoreValue(
-            np.nanmean(
-                [c.low_score for c in self.pages.values()],
-            )
-        )
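
The removed `ConfidenceReport` derived its document-level scores as computed properties over `pages`; the new one (earlier in this diff) leaves them as plain fields for the pipeline to set. A hedged sketch of how a pipeline might populate them, mirroring the removed properties (the aggregation code and the `page_scores` name below are assumptions, not code from this commit):

    import numpy as np

    # page_scores: Dict[int, PageConfidenceScores], built per page upstream
    report = ConfidenceReport(pages=page_scores)
    report.mean_score = ScoreValue(
        np.nanmean([p.mean_score for p in report.pages.values()])
    )
    report.low_score = ScoreValue(
        np.nanmean([p.low_score for p in report.pages.values()])
    )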
