4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
## 1.0.5

* feat: add thread lock to prevent race condition when instantiating singletons

## 1.0.4

* feat: use singleton instead of `global` to store shared variables
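For context, the 1.0.5 entry refers to guarding singleton creation with a `threading.Lock`, the same pattern applied to `Models.__new__`, `UnstructuredTableTransformerModel.__new__`, and `load_agent` in this PR. A minimal, standalone sketch of that pattern (illustrative only; the class name here is hypothetical and not part of the library):

```python
import threading


class Singleton:
    """Minimal lock-guarded singleton, sketched for illustration."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Serialize instantiation so two threads racing through __new__
        # cannot each create (and partially initialize) a separate instance.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
        return cls._instance
```

Every `Singleton()` call across threads then returns the same object, which is the property `test_model_init_is_thread_safe` below exercises against `tables.load_agent`.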
39 changes: 31 additions & 8 deletions test_unstructured_inference/models/test_tables.py
@@ -1,4 +1,6 @@
import os
import threading
from copy import deepcopy

import numpy as np
import pytest
@@ -7,7 +9,6 @@
from transformers.models.table_transformer.modeling_table_transformer import (
TableTransformerDecoder,
)
from copy import deepcopy

import unstructured_inference.models.table_postprocess as postprocess
from unstructured_inference.models import tables
@@ -572,7 +573,7 @@ def test_load_table_model_raises_when_not_available(model_path):


@pytest.mark.parametrize(
"bbox1, bbox2, expected_result",
("bbox1", "bbox2", "expected_result"),
[
((0, 0, 5, 5), (2, 2, 7, 7), 0.36),
((0, 0, 0, 0), (6, 6, 10, 10), 0),
@@ -921,7 +922,9 @@ def test_table_prediction_output_format(
)
if output_format:
result = table_transformer.run_prediction(
example_image, result_format=output_format, ocr_tokens=mocked_ocr_tokens
example_image,
result_format=output_format,
ocr_tokens=mocked_ocr_tokens,
)
else:
result = table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens)
@@ -952,7 +955,9 @@ def test_table_prediction_output_format_when_wrong_type_then_value_error(
)
with pytest.raises(ValueError):
table_transformer.run_prediction(
example_image, result_format="Wrong format", ocr_tokens=mocked_ocr_tokens
example_image,
result_format="Wrong format",
ocr_tokens=mocked_ocr_tokens,
)


@@ -991,7 +996,8 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
],
)
def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold(
thresholds, expected_object_number
thresholds,
expected_object_number,
):
objects = [
{"label": "0", "score": 0.2},
@@ -1010,7 +1016,8 @@ def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_(
],
)
def test_objects_are_filtered_based_on_class_thresholds_when_two_classes(
thresholds, expected_object_number
thresholds,
expected_object_number,
):
objects = [
{"label": "0", "score": 0.2},
@@ -1800,7 +1807,7 @@ def test_compute_confidence_score_zero_division_error_handling():


@pytest.mark.parametrize(
"column_span_score, row_span_score, expected_text_to_indexes",
("column_span_score", "row_span_score", "expected_text_to_indexes"),
[
(
0.9,
@@ -1827,7 +1834,9 @@ def test_compute_confidence_score_zero_division_error_handling():
],
)
def test_subcells_filtering_when_overlapping_spanning_cells(
column_span_score, row_span_score, expected_text_to_indexes
column_span_score,
row_span_score,
expected_text_to_indexes,
):
"""
# table
@@ -1894,3 +1903,17 @@ def test_subcells_filtering_when_overlapping_spanning_cells(

predicted_cells_after_reorder, _ = structure_to_cells(saved_table_structure, tokens=tokens)
assert predicted_cells_after_reorder == predicted_cells


def test_model_init_is_thread_safe():
threads = []
tables.tables_agent.model = None
for i in range(5):
thread = threading.Thread(target=tables.load_agent)
threads.append(thread)
thread.start()

for thread in threads:
thread.join()

assert tables.tables_agent.model is not None
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
@@ -1 +1 @@
__version__ = "1.0.4" # pragma: no cover
__version__ = "1.0.5" # pragma: no cover
8 changes: 6 additions & 2 deletions unstructured_inference/models/base.py
@@ -2,6 +2,7 @@

import json
import os
import threading
from typing import Dict, Optional, Tuple, Type

from unstructured_inference.models.detectron2onnx import (
@@ -18,12 +19,15 @@

class Models(object):
_instance = None
_lock = threading.Lock()

def __new__(cls):
"""return an instance if one already exists otherwise create an instance"""
if cls._instance is None:
cls._instance = super(Models, cls).__new__(cls)
cls.models: Dict[str, UnstructuredModel] = {}
with cls._lock:
if cls._instance is None:
cls._instance = super(Models, cls).__new__(cls)
cls.models: Dict[str, UnstructuredModel] = {}
return cls._instance

def __contains__(self, key):
26 changes: 15 additions & 11 deletions unstructured_inference/models/tables.py
@@ -1,5 +1,6 @@
# https://github.com/microsoft/table-transformer/blob/main/src/inference.py
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb
import threading
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
@@ -23,20 +24,21 @@

from . import table_postprocess as postprocess

DEFAULT_MODEL = "microsoft/table-transformer-structure-recognition"


class UnstructuredTableTransformerModel(UnstructuredModel):
"""Unstructured model wrapper for table-transformer."""

_instance = None
_lock = threading.Lock()

def __init__(self):
pass

@classmethod
def instance(cls):
def __new__(cls):
"""return an instance if one already exists otherwise create an instance"""
if cls._instance is None:
cls._instance = cls.__new__(cls)
with cls._lock:
if cls._instance is None:
cls._instance = super(UnstructuredTableTransformerModel, cls).__new__(cls)
return cls._instance

def predict(
@@ -70,7 +72,7 @@ def initialize(
):
"""Loads the donut model using the specified parameters"""
self.device = device
self.feature_extractor = DetrImageProcessor.from_pretrained(model)
self.feature_extractor = DetrImageProcessor.from_pretrained(model, device_map=self.device)
# value not set in the configuration and needed for newer models
# https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all/discussions/1
self.feature_extractor.size["shortest_edge"] = 800
@@ -145,15 +147,17 @@ def run_prediction(
return prediction


tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel.instance()
tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel()


def load_agent():
"""Loads the Table agent."""

if not hasattr(tables_agent, "model"):
logger.info("Loading the Table agent ...")
tables_agent.initialize("microsoft/table-transformer-structure-recognition")
if getattr(tables_agent, "model", None) is None:
with tables_agent._lock:
if getattr(tables_agent, "model", None) is None:
logger.info("Loading the Table agent ...")
tables_agent.initialize(DEFAULT_MODEL)

return
