3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,6 @@
## 1.0.5-dev0
## 1.0.5

* feat: add thread lock to prevent race condition when instantiating singletons
* feat: parametrize edge config for `DetrImageProcessor` with env variables

## 1.0.4
39 changes: 31 additions & 8 deletions test_unstructured_inference/models/test_tables.py
@@ -1,4 +1,6 @@
import os
import threading
from copy import deepcopy

import numpy as np
import pytest
@@ -7,7 +9,6 @@
from transformers.models.table_transformer.modeling_table_transformer import (
TableTransformerDecoder,
)
from copy import deepcopy

import unstructured_inference.models.table_postprocess as postprocess
from unstructured_inference.models import tables
@@ -572,7 +573,7 @@ def test_load_table_model_raises_when_not_available(model_path):


@pytest.mark.parametrize(
"bbox1, bbox2, expected_result",
("bbox1", "bbox2", "expected_result"),
[
((0, 0, 5, 5), (2, 2, 7, 7), 0.36),
((0, 0, 0, 0), (6, 6, 10, 10), 0),
@@ -921,7 +922,9 @@ def test_table_prediction_output_format(
)
if output_format:
result = table_transformer.run_prediction(
example_image, result_format=output_format, ocr_tokens=mocked_ocr_tokens
example_image,
result_format=output_format,
ocr_tokens=mocked_ocr_tokens,
)
else:
result = table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens)
@@ -952,7 +955,9 @@ def test_table_prediction_output_format_when_wrong_type_then_value_error(
)
with pytest.raises(ValueError):
table_transformer.run_prediction(
example_image, result_format="Wrong format", ocr_tokens=mocked_ocr_tokens
example_image,
result_format="Wrong format",
ocr_tokens=mocked_ocr_tokens,
)


@@ -991,7 +996,8 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
],
)
def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold(
thresholds, expected_object_number
thresholds,
expected_object_number,
):
objects = [
{"label": "0", "score": 0.2},
@@ -1010,7 +1016,8 @@ def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_
],
)
def test_objects_are_filtered_based_on_class_thresholds_when_two_classes(
thresholds, expected_object_number
thresholds,
expected_object_number,
):
objects = [
{"label": "0", "score": 0.2},
@@ -1800,7 +1807,7 @@ def test_compute_confidence_score_zero_division_error_handling():


@pytest.mark.parametrize(
"column_span_score, row_span_score, expected_text_to_indexes",
("column_span_score", "row_span_score", "expected_text_to_indexes"),
[
(
0.9,
@@ -1827,7 +1834,9 @@ def test_compute_confidence_score_zero_division_error_handling():
],
)
def test_subcells_filtering_when_overlapping_spanning_cells(
column_span_score, row_span_score, expected_text_to_indexes
column_span_score,
row_span_score,
expected_text_to_indexes,
):
"""
# table
@@ -1894,3 +1903,17 @@ def test_subcells_filtering_when_overlapping_spanning_cells(

predicted_cells_after_reorder, _ = structure_to_cells(saved_table_structure, tokens=tokens)
assert predicted_cells_after_reorder == predicted_cells


def test_model_init_is_thread_safe():
threads = []
tables.tables_agent.model = None
for i in range(5):
thread = threading.Thread(target=tables.load_agent)
threads.append(thread)
thread.start()

for thread in threads:
thread.join()

assert tables.tables_agent.model is not None
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
@@ -1 +1 @@
__version__ = "1.0.5-dev0" # pragma: no cover
__version__ = "1.0.5" # pragma: no cover
8 changes: 6 additions & 2 deletions unstructured_inference/models/base.py
@@ -2,6 +2,7 @@

import json
import os
import threading
from typing import Dict, Optional, Tuple, Type

from unstructured_inference.models.detectron2onnx import (
@@ -18,12 +19,15 @@

class Models(object):
_instance = None
_lock = threading.Lock()

def __new__(cls):
"""return an instance if one already exists otherwise create an instance"""
if cls._instance is None:
cls._instance = super(Models, cls).__new__(cls)
cls.models: Dict[str, UnstructuredModel] = {}
with cls._lock:
if cls._instance is None:
cls._instance = super(Models, cls).__new__(cls)
cls.models: Dict[str, UnstructuredModel] = {}
return cls._instance

def __contains__(self, key):
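For context on the `base.py` change above, here is a minimal, self-contained sketch of the double-checked locking idiom the `Models` singleton now follows. The class name, worker count, and the trivial payload are illustrative only, not part of the repository.

```python
import threading


class Singleton:
    """Illustrative stand-in for the Models registry above."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Fast path: once the instance exists, no lock is taken.
        if cls._instance is None:
            # Slow path: only one thread creates the instance; the second
            # check covers threads that were blocked waiting on the lock.
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance


# Quick check: concurrent construction still yields a single object.
results = []
threads = [threading.Thread(target=lambda: results.append(Singleton())) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert len({id(obj) for obj in results}) == 1
```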
26 changes: 15 additions & 11 deletions unstructured_inference/models/tables.py
@@ -1,5 +1,6 @@
# https://github.com/microsoft/table-transformer/blob/main/src/inference.py
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb
import threading
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
@@ -23,20 +24,21 @@

from . import table_postprocess as postprocess

DEFAULT_MODEL = "microsoft/table-transformer-structure-recognition"


class UnstructuredTableTransformerModel(UnstructuredModel):
"""Unstructured model wrapper for table-transformer."""

_instance = None
_lock = threading.Lock()

def __init__(self):
pass

@classmethod
def instance(cls):
def __new__(cls):
"""return an instance if one already exists otherwise create an instance"""
if cls._instance is None:
cls._instance = cls.__new__(cls)
with cls._lock:
if cls._instance is None:
cls._instance = super(UnstructuredTableTransformerModel, cls).__new__(cls)
return cls._instance

def predict(
@@ -70,7 +72,7 @@ def initialize(
):
"""Loads the donut model using the specified parameters"""
self.device = device
self.feature_extractor = DetrImageProcessor.from_pretrained(model)
self.feature_extractor = DetrImageProcessor.from_pretrained(model, device_map=self.device)
# value not set in the configuration and needed for newer models
# https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all/discussions/1
self.feature_extractor.size["shortest_edge"] = inference_config.IMG_PROCESSOR_SHORTEST_EDGE
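As a side note on the `DetrImageProcessor` hunk above, a hedged sketch of what pinning `size["shortest_edge"]` amounts to after loading the processor. The 800 value is only an illustrative default, and the environment-variable wiring behind `inference_config` is not reproduced here.

```python
from transformers import DetrImageProcessor

processor = DetrImageProcessor.from_pretrained(
    "microsoft/table-transformer-structure-recognition"
)
# Some checkpoints do not carry an explicit edge size in their config, so the
# model code pins the resize target after loading; 800 is illustrative only.
processor.size["shortest_edge"] = 800
```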
@@ -145,15 +147,17 @@ def run_prediction(
return prediction


tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel.instance()
tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel()


def load_agent():
"""Loads the Table agent."""

if not hasattr(tables_agent, "model"):
logger.info("Loading the Table agent ...")
tables_agent.initialize("microsoft/table-transformer-structure-recognition")
if getattr(tables_agent, "model", None) is None:
with tables_agent._lock:
if getattr(tables_agent, "model", None) is None:
logger.info("Loading the Table agent ...")
tables_agent.initialize(DEFAULT_MODEL)

return

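Finally, a hedged usage sketch of the lazy, thread-safe loading path that `load_agent` implements after this change, mirroring the new test above. The image and OCR-token inputs are placeholders, not fixtures from this repository.

```python
import threading

from unstructured_inference.models import tables

# Calling load_agent from several threads should initialize the model once.
workers = [threading.Thread(target=tables.load_agent) for _ in range(4)]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()

assert tables.tables_agent.model is not None

# Predictions then go through the shared singleton instance
# (`image` and `ocr_tokens` are placeholders for real inputs):
# result = tables.tables_agent.run_prediction(image, ocr_tokens=ocr_tokens)
```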