 import os
 import threading
 from collections.abc import Iterable
-from typing import Set, Union
+from typing import Dict, List, Set, Union

 import numpy as np
 import torch
-import torchvision.transforms as T
 from PIL import Image
-from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
+from torch import Tensor
+from transformers import AutoModelForObjectDetection, RTDetrImageProcessor
+
+from docling_ibm_models.layoutmodel.labels import LayoutLabels

 _log = logging.getLogger(__name__)

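A note on the import swap above: `AutoModelForObjectDetection` inspects the checkpoint's `config.json` and instantiates the matching architecture, so one loader can serve more than the hard-coded `RTDetrForObjectDetection`. A minimal sketch of the dispatch, assuming a local checkpoint directory at a hypothetical path:

    # Sketch only (not from this PR): the Auto class dispatches on config.json,
    # returning e.g. an RTDetrForObjectDetection instance for RT-DETR weights.
    from transformers import AutoModelForObjectDetection

    model = AutoModelForObjectDetection.from_pretrained("/path/to/artifacts")
    print(type(model).__name__)  # e.g. "RTDetrForObjectDetection"
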
@@ -46,70 +48,67 @@ def __init__(
         ------
         FileNotFoundError when the model's torch file is missing
         """
-        # Initialize classes map:
-        self._classes_map = {
-            0: "background",
-            1: "Caption",
-            2: "Footnote",
-            3: "Formula",
-            4: "List-item",
-            5: "Page-footer",
-            6: "Page-header",
-            7: "Picture",
-            8: "Section-header",
-            9: "Table",
-            10: "Text",
-            11: "Title",
-            12: "Document Index",
-            13: "Code",
-            14: "Checkbox-Selected",
-            15: "Checkbox-Unselected",
-            16: "Form",
-            17: "Key-Value Region",
-        }
-
         # Blacklisted classes
         self._black_classes = blacklist_classes  # set(["Form", "Key-Value Region"])

+        # Canonical classes
+        self._labels = LayoutLabels()
+
         # Set basic params
         self._threshold = base_threshold  # Score threshold
-        self._image_size = 640
-        self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)

         # Set number of threads for CPU
         self._device = torch.device(device)
         self._num_threads = num_threads
         if device == "cpu":
             torch.set_num_threads(self._num_threads)

-        # Model file and configurations
+        # Load model file and configurations
+        self._processor_config = os.path.join(artifact_path, "preprocessor_config.json")
+        self._model_config = os.path.join(artifact_path, "config.json")
         self._st_fn = os.path.join(artifact_path, "model.safetensors")
         if not os.path.isfile(self._st_fn):
             raise FileNotFoundError("Missing safe tensors file: {}".format(self._st_fn))
+        if not os.path.isfile(self._processor_config):
+            raise FileNotFoundError(
+                f"Missing processor config file: {self._processor_config}"
+            )
+        if not os.path.isfile(self._model_config):
+            raise FileNotFoundError(f"Missing model config file: {self._model_config}")

         # Load model and move to device
-        processor_config = os.path.join(artifact_path, "preprocessor_config.json")
-        model_config = os.path.join(artifact_path, "config.json")
-        self._image_processor = RTDetrImageProcessor.from_json_file(processor_config)
+        self._image_processor = RTDetrImageProcessor.from_json_file(
+            self._processor_config
+        )

         # Use lock to prevent threading issues during model initialization
         with _model_init_lock:
-            self._model = RTDetrForObjectDetection.from_pretrained(
-                artifact_path, config=model_config
+            self._model = AutoModelForObjectDetection.from_pretrained(
+                artifact_path, config=self._model_config
             ).to(self._device)
             self._model.eval()

+        # Set classes map
+        self._model_name = type(self._model).__name__
+        if self._model_name == "RTDetrForObjectDetection":
+            self._classes_map = self._labels.shifted_canonical_categories()
+            self._label_offset = 1
+        else:
+            self._classes_map = self._labels.canonical_categories()
+            self._label_offset = 0
+
         _log.debug("LayoutPredictor settings: {}".format(self.info()))

     def info(self) -> dict:
         """
         Get information about the configuration of LayoutPredictor
         """
         info = {
+            "model_name": self._model_name,
             "safe_tensors_file": self._st_fn,
             "device": self._device.type,
             "num_threads": self._num_threads,
-            "image_size": self._image_size,
+            "image_size": self._image_processor.size,
             "threshold": self._threshold,
         }
         return info
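Why the `_label_offset` introduced above is needed: the RT-DETR detection head emits raw class ids starting at 0 for the first real label, while the canonical map reserves id 0 for "background", exactly as the deleted hardcoded map did. A small sketch of the arithmetic, using an abbreviated label map assumed from that deleted dict (the real one comes from `LayoutLabels`):

    # Abbreviated, assumed label map; ids shifted so 0 is "background".
    shifted = {0: "background", 1: "Caption", 2: "Footnote", 3: "Formula"}

    label_offset = 1  # RTDetrForObjectDetection: raw id 0 means "Caption"
    raw_id = 2        # as read from result["labels"]
    print(shifted[raw_id + label_offset])  # -> "Formula"

Other architectures are assumed to emit canonical ids directly, so their offset is 0.
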
@@ -141,28 +140,27 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
         else:
             raise TypeError("Not supported input image format")

-        resize = {"height": self._image_size, "width": self._image_size}
-        inputs = self._image_processor(
-            images=page_img,
-            return_tensors="pt",
-            size=resize,
-        ).to(self._device)
+        target_sizes = torch.tensor([page_img.size[::-1]])
+        inputs = self._image_processor(images=[page_img], return_tensors="pt").to(
+            self._device
+        )
         outputs = self._model(**inputs)
-        results = self._image_processor.post_process_object_detection(
-            outputs,
-            target_sizes=torch.tensor([page_img.size[::-1]]),
-            threshold=self._threshold,
+        results: List[Dict[str, Tensor]] = (
+            self._image_processor.post_process_object_detection(
+                outputs,
+                target_sizes=target_sizes,
+                threshold=self._threshold,
+            )
         )

         w, h = page_img.size
-
         result = results[0]
         for score, label_id, box in zip(
             result["scores"], result["labels"], result["boxes"]
         ):
             score = float(score.item())

-            label_id = int(label_id.item()) + 1  # Advance the label_id
+            label_id = int(label_id.item()) + self._label_offset
             label_str = self._classes_map[label_id]

             # Filter out blacklisted classes
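Putting the changes together, a hedged usage sketch; the constructor's argument names are inferred from the identifiers used in this diff and may differ in the full class definition:

    from PIL import Image

    # Assumed signature: artifact_path must contain model.safetensors,
    # config.json and preprocessor_config.json, or __init__ raises FileNotFoundError.
    predictor = LayoutPredictor(
        artifact_path="/path/to/artifacts",
        device="cpu",
        num_threads=4,
        base_threshold=0.3,
    )
    page = Image.open("page.png").convert("RGB")
    for pred in predictor.predict(page):
        print(pred)  # one prediction per detected layout element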