|
28 | 28 | from unstructured_inference.inference.ordering import order_layout |
29 | 29 | from unstructured_inference.logger import logger |
30 | 30 | from unstructured_inference.models.base import get_model |
| 31 | +from unstructured_inference.models.detectron2onnx import ( |
| 32 | + UnstructuredDetectronONNXModel, |
| 33 | +) |
31 | 34 | from unstructured_inference.models.unstructuredmodel import ( |
32 | 35 | UnstructuredElementExtractionModel, |
33 | 36 | UnstructuredObjectDetectionModel, |
@@ -262,15 +265,33 @@ def get_elements_with_detection_model( |
262 | 265 | raise ValueError("Invalid OCR mode") |
263 | 266 |
|
264 | 267 | if self.layout is not None: |
| 268 | + threshold_kwargs = {} |
| 269 | + # NOTE(Benjamin): With this the thresholds are only changed for detextron2_mask_rcnn |
| 270 | + # In other case the default values for the functions are used |
| 271 | + if ( |
| 272 | + isinstance(self.detection_model, UnstructuredDetectronONNXModel) |
| 273 | + and "R_50" not in self.detection_model.model_path |
| 274 | + ): |
| 275 | + threshold_kwargs = {"same_region_threshold": 0.5, "subregion_threshold": 0.5} |
265 | 276 | inferred_layout = merge_inferred_layout_with_extracted_layout( |
266 | 277 | inferred_layout=inferred_layout, |
267 | 278 | extracted_layout=self.layout, |
268 | 279 | ocr_layout=ocr_layout, |
| 280 | + **threshold_kwargs, |
269 | 281 | ) |
270 | 282 | elif ocr_layout is not None: |
| 283 | + threshold_kwargs = {} |
| 284 | + # NOTE(Benjamin): With this the thresholds are only changed for detextron2_mask_rcnn |
| 285 | + # In other case the default values for the functions are used |
| 286 | + if ( |
| 287 | + isinstance(self.detection_model, UnstructuredDetectronONNXModel) |
| 288 | + and "R_50" not in self.detection_model.model_path |
| 289 | + ): |
| 290 | + threshold_kwargs = {"subregion_threshold": 0.3} |
271 | 291 | inferred_layout = merge_inferred_layout_with_ocr_layout( |
272 | 292 | inferred_layout=inferred_layout, |
273 | 293 | ocr_layout=ocr_layout, |
| 294 | + **threshold_kwargs, |
274 | 295 | ) |
275 | 296 |
|
276 | 297 | elements = self.get_elements_from_layout(cast(List[TextRegion], inferred_layout)) |
|
0 commit comments