3 | 3 | import xml.etree.ElementTree as ET |
4 | 4 | from collections import defaultdict |
5 | 5 | from pathlib import Path |
6 | | -from typing import Dict, List, Optional, Union |
| 6 | +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union |
7 | 7 |
8 | 8 | import cv2 |
9 | 9 | import numpy as np |
10 | 10 | import torch |
11 | 11 | from PIL import Image as PILImage |
12 | 12 | from transformers import DetrImageProcessor, TableTransformerForObjectDetection |
| 13 | +from transformers.models.table_transformer.modeling_table_transformer import ( |
| 14 | + TableTransformerObjectDetectionOutput, |
| 15 | +) |
13 | 16 |
14 | 17 | from unstructured_inference.config import inference_config |
15 | 18 | from unstructured_inference.inference.layoutelement import table_cells_to_dataframe |
@@ -172,18 +175,22 @@ def recognize(outputs: dict, img: PILImage.Image, tokens: list): |
172 | 175 | """Recognize table elements.""" |
173 | 176 | str_class_name2idx = get_class_map("structure") |
174 | 177 | str_class_idx2name = {v: k for k, v in str_class_name2idx.items()} |
175 | | - str_class_thresholds = structure_class_thresholds |
| 178 | + class_thresholds = structure_class_thresholds |
176 | 179 |
177 | 180 | # Post-process detected objects, assign class labels |
178 | 181 | objects = outputs_to_objects(outputs, img.size, str_class_idx2name) |
179 | | - |
| 182 | + high_confidence_objects = apply_thresholds_on_objects(objects, class_thresholds) |
180 | 183 | # Further process the detected objects so they correspond to a consistent table |
181 | | - tables_structure = objects_to_structures(objects, tokens, str_class_thresholds) |
| 184 | + tables_structure = objects_to_structures(high_confidence_objects, tokens, class_thresholds) |
182 | 185 | # Enumerate all table cells: grid cells and spanning cells |
183 | 186 | return [structure_to_cells(structure, tokens)[0] for structure in tables_structure] |
184 | 187 |
185 | 188 |
186 | | -def outputs_to_objects(outputs, img_size, class_idx2name): |
| 189 | +def outputs_to_objects( |
| 190 | + outputs: TableTransformerObjectDetectionOutput, |
| 191 | + img_size: tuple[int, int], |
| 192 | + class_idx2name: Mapping[int, str], |
| 193 | +): |
187 | 194 | """Output table element types.""" |
188 | 195 | m = outputs["logits"].softmax(-1).max(-1) |
189 | 196 | pred_labels = list(m.indices.detach().cpu().numpy())[0] |
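
For context, a minimal sketch of how this recognize() path might be driven end to end, assuming the public microsoft/table-transformer-structure-recognition checkpoint and an empty OCR token list (neither is fixed by this diff):

    import torch
    from PIL import Image
    from transformers import DetrImageProcessor, TableTransformerForObjectDetection

    processor = DetrImageProcessor.from_pretrained(
        "microsoft/table-transformer-structure-recognition"
    )
    model = TableTransformerForObjectDetection.from_pretrained(
        "microsoft/table-transformer-structure-recognition"
    )

    img = Image.open("table_crop.png").convert("RGB")
    encoding = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)  # a TableTransformerObjectDetectionOutput

    # recognize() turns the raw detections into per-table lists of cells.
    cells_per_table = recognize(outputs, img, tokens=[])
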
@@ -212,6 +219,32 @@ def outputs_to_objects(outputs, img_size, class_idx2name): |
212 | 219 | return objects |
213 | 220 |
214 | 221 |
| 222 | +def apply_thresholds_on_objects( |
| 223 | + objects: Sequence[Mapping[str, Any]], thresholds: Mapping[str, float] |
| 224 | +) -> Sequence[Mapping[str, Any]]: |
| 225 | + """ |
| 226 | + Filters out predicted objects whose confidence scores fall below their per-label thresholds
| 227 | +
| 228 | + Args: |
| 229 | + objects: Sequence of object mappings, for example:
| 230 | + [ |
| 231 | + { |
| 232 | + "label": "table row", |
| 233 | + "score": 0.55, |
| 234 | + "bbox": [...], |
| 235 | + }, |
| 236 | + ..., |
| 237 | + ] |
| 238 | + thresholds: Mapping from class labels to confidence score thresholds
| 239 | +
| 240 | + Returns: |
| 241 | + Filtered list of objects |
| 242 | +
| 243 | + """ |
| 244 | + objects = [obj for obj in objects if obj["score"] >= thresholds[obj["label"]]] |
| 245 | + return objects |
| 246 | + |
| 247 | + |
215 | 248 | # for output bounding box post-processing |
216 | 249 | def box_cxcywh_to_xyxy(x): |
217 | 250 | """Convert rectangle format from center-x, center-y, width, height to |
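
A quick illustration of the new filtering step, with made-up scores and thresholds (the real values come from structure_class_thresholds):

    objects = [
        {"label": "table row", "score": 0.85, "bbox": [0, 0, 100, 10]},
        {"label": "table row", "score": 0.30, "bbox": [0, 10, 100, 20]},
        {"label": "table column", "score": 0.90, "bbox": [0, 0, 10, 100]},
    ]
    thresholds = {"table row": 0.5, "table column": 0.5}

    kept = apply_thresholds_on_objects(objects, thresholds)
    # The 0.30 "table row" falls below its label's threshold and is dropped;
    # a label missing from `thresholds` would raise KeyError.
    assert [obj["score"] for obj in kept] == [0.85, 0.90]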