11import os
2+ import threading
3+ from copy import deepcopy
24
35import numpy as np
46import pytest
79from transformers .models .table_transformer .modeling_table_transformer import (
810 TableTransformerDecoder ,
911)
10- from copy import deepcopy
1112
1213import unstructured_inference .models .table_postprocess as postprocess
1314from unstructured_inference .models import tables
@@ -572,7 +573,7 @@ def test_load_table_model_raises_when_not_available(model_path):
572573
573574
574575@pytest .mark .parametrize (
575- "bbox1, bbox2, expected_result" ,
576+ ( "bbox1" , " bbox2" , " expected_result") ,
576577 [
577578 ((0 , 0 , 5 , 5 ), (2 , 2 , 7 , 7 ), 0.36 ),
578579 ((0 , 0 , 0 , 0 ), (6 , 6 , 10 , 10 ), 0 ),
@@ -921,7 +922,9 @@ def test_table_prediction_output_format(
921922 )
922923 if output_format :
923924 result = table_transformer .run_prediction (
924- example_image , result_format = output_format , ocr_tokens = mocked_ocr_tokens
925+ example_image ,
926+ result_format = output_format ,
927+ ocr_tokens = mocked_ocr_tokens ,
925928 )
926929 else :
927930 result = table_transformer .run_prediction (example_image , ocr_tokens = mocked_ocr_tokens )
@@ -952,7 +955,9 @@ def test_table_prediction_output_format_when_wrong_type_then_value_error(
952955 )
953956 with pytest .raises (ValueError ):
954957 table_transformer .run_prediction (
955- example_image , result_format = "Wrong format" , ocr_tokens = mocked_ocr_tokens
958+ example_image ,
959+ result_format = "Wrong format" ,
960+ ocr_tokens = mocked_ocr_tokens ,
956961 )
957962
958963
@@ -991,7 +996,8 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
991996 ],
992997)
993998def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold (
994- thresholds , expected_object_number
999+ thresholds ,
1000+ expected_object_number ,
9951001):
9961002 objects = [
9971003 {"label" : "0" , "score" : 0.2 },
@@ -1010,7 +1016,8 @@ def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_
10101016 ],
10111017)
10121018def test_objects_are_filtered_based_on_class_thresholds_when_two_classes (
1013- thresholds , expected_object_number
1019+ thresholds ,
1020+ expected_object_number ,
10141021):
10151022 objects = [
10161023 {"label" : "0" , "score" : 0.2 },
@@ -1800,7 +1807,7 @@ def test_compute_confidence_score_zero_division_error_handling():
18001807
18011808
18021809@pytest .mark .parametrize (
1803- "column_span_score, row_span_score, expected_text_to_indexes" ,
1810+ ( "column_span_score" , " row_span_score" , " expected_text_to_indexes") ,
18041811 [
18051812 (
18061813 0.9 ,
@@ -1827,7 +1834,9 @@ def test_compute_confidence_score_zero_division_error_handling():
18271834 ],
18281835)
18291836def test_subcells_filtering_when_overlapping_spanning_cells (
1830- column_span_score , row_span_score , expected_text_to_indexes
1837+ column_span_score ,
1838+ row_span_score ,
1839+ expected_text_to_indexes ,
18311840):
18321841 """
18331842 # table
@@ -1894,3 +1903,17 @@ def test_subcells_filtering_when_overlapping_spanning_cells(
18941903
18951904 predicted_cells_after_reorder , _ = structure_to_cells (saved_table_structure , tokens = tokens )
18961905 assert predicted_cells_after_reorder == predicted_cells
1906+
1907+
1908+ def test_model_init_is_thread_safe ():
1909+ threads = []
1910+ tables .tables_agent .model = None
1911+ for i in range (5 ):
1912+ thread = threading .Thread (target = tables .load_agent )
1913+ threads .append (thread )
1914+ thread .start ()
1915+
1916+ for thread in threads :
1917+ thread .join ()
1918+
1919+ assert tables .tables_agent .model is not None
0 commit comments