diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f4e35e8..8ea6dae6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ -## 1.0.5-dev0 +## 1.0.5 +* feat: add thread lock to prevent racing condition when instantiating singletons * feat: parametrize edge config for `DetrImageProcessor` with env variables ## 1.0.4 diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 03c9b1fd..65a71729 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -1,4 +1,6 @@ import os +import threading +from copy import deepcopy import numpy as np import pytest @@ -7,7 +9,6 @@ from transformers.models.table_transformer.modeling_table_transformer import ( TableTransformerDecoder, ) -from copy import deepcopy import unstructured_inference.models.table_postprocess as postprocess from unstructured_inference.models import tables @@ -572,7 +573,7 @@ def test_load_table_model_raises_when_not_available(model_path): @pytest.mark.parametrize( - "bbox1, bbox2, expected_result", + ("bbox1", "bbox2", "expected_result"), [ ((0, 0, 5, 5), (2, 2, 7, 7), 0.36), ((0, 0, 0, 0), (6, 6, 10, 10), 0), @@ -921,7 +922,9 @@ def test_table_prediction_output_format( ) if output_format: result = table_transformer.run_prediction( - example_image, result_format=output_format, ocr_tokens=mocked_ocr_tokens + example_image, + result_format=output_format, + ocr_tokens=mocked_ocr_tokens, ) else: result = table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens) @@ -952,7 +955,9 @@ def test_table_prediction_output_format_when_wrong_type_then_value_error( ) with pytest.raises(ValueError): table_transformer.run_prediction( - example_image, result_format="Wrong format", ocr_tokens=mocked_ocr_tokens + example_image, + result_format="Wrong format", + ocr_tokens=mocked_ocr_tokens, ) @@ -991,7 +996,8 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image): ], ) def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold( - thresholds, expected_object_number + thresholds, + expected_object_number, ): objects = [ {"label": "0", "score": 0.2}, @@ -1010,7 +1016,8 @@ def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_ ], ) def test_objects_are_filtered_based_on_class_thresholds_when_two_classes( - thresholds, expected_object_number + thresholds, + expected_object_number, ): objects = [ {"label": "0", "score": 0.2}, @@ -1800,7 +1807,7 @@ def test_compute_confidence_score_zero_division_error_handling(): @pytest.mark.parametrize( - "column_span_score, row_span_score, expected_text_to_indexes", + ("column_span_score", "row_span_score", "expected_text_to_indexes"), [ ( 0.9, @@ -1827,7 +1834,9 @@ def test_compute_confidence_score_zero_division_error_handling(): ], ) def test_subcells_filtering_when_overlapping_spanning_cells( - column_span_score, row_span_score, expected_text_to_indexes + column_span_score, + row_span_score, + expected_text_to_indexes, ): """ # table @@ -1894,3 +1903,17 @@ def test_subcells_filtering_when_overlapping_spanning_cells( predicted_cells_after_reorder, _ = structure_to_cells(saved_table_structure, tokens=tokens) assert predicted_cells_after_reorder == predicted_cells + + +def test_model_init_is_thread_safe(): + threads = [] + tables.tables_agent.model = None + for i in range(5): + thread = threading.Thread(target=tables.load_agent) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert tables.tables_agent.model is not None diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 638333d9..f7b35997 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "1.0.5-dev0" # pragma: no cover +__version__ = "1.0.5" # pragma: no cover diff --git a/unstructured_inference/models/base.py b/unstructured_inference/models/base.py index 37344100..a393df7e 100644 --- a/unstructured_inference/models/base.py +++ b/unstructured_inference/models/base.py @@ -2,6 +2,7 @@ import json import os +import threading from typing import Dict, Optional, Tuple, Type from unstructured_inference.models.detectron2onnx import ( @@ -18,12 +19,15 @@ class Models(object): _instance = None + _lock = threading.Lock() def __new__(cls): """return an instance if one already exists otherwise create an instance""" if cls._instance is None: - cls._instance = super(Models, cls).__new__(cls) - cls.models: Dict[str, UnstructuredModel] = {} + with cls._lock: + if cls._instance is None: + cls._instance = super(Models, cls).__new__(cls) + cls.models: Dict[str, UnstructuredModel] = {} return cls._instance def __contains__(self, key): diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index 64d2929d..c994207b 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -1,5 +1,6 @@ # https://github.com/microsoft/table-transformer/blob/main/src/inference.py # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb +import threading import xml.etree.ElementTree as ET from collections import defaultdict from pathlib import Path @@ -23,20 +24,21 @@ from . import table_postprocess as postprocess +DEFAULT_MODEL = "microsoft/table-transformer-structure-recognition" + class UnstructuredTableTransformerModel(UnstructuredModel): """Unstructured model wrapper for table-transformer.""" _instance = None + _lock = threading.Lock() - def __init__(self): - pass - - @classmethod - def instance(cls): + def __new__(cls): """return an instance if one already exists otherwise create an instance""" if cls._instance is None: - cls._instance = cls.__new__(cls) + with cls._lock: + if cls._instance is None: + cls._instance = super(UnstructuredTableTransformerModel, cls).__new__(cls) return cls._instance def predict( @@ -70,7 +72,7 @@ def initialize( ): """Loads the donut model using the specified parameters""" self.device = device - self.feature_extractor = DetrImageProcessor.from_pretrained(model) + self.feature_extractor = DetrImageProcessor.from_pretrained(model, device_map=self.device) # value not set in the configuration and needed for newer models # https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all/discussions/1 self.feature_extractor.size["shortest_edge"] = inference_config.IMG_PROCESSOR_SHORTEST_EDGE @@ -145,15 +147,17 @@ def run_prediction( return prediction -tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel.instance() +tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel() def load_agent(): """Loads the Table agent.""" - if not hasattr(tables_agent, "model"): - logger.info("Loading the Table agent ...") - tables_agent.initialize("microsoft/table-transformer-structure-recognition") + if getattr(tables_agent, "model", None) is None: + with tables_agent._lock: + if getattr(tables_agent, "model", None) is None: + logger.info("Loading the Table agent ...") + tables_agent.initialize(DEFAULT_MODEL) return