Skip to content

Commit 8407827

Browse files
authored
feat: use singleton instead of global (#429)
- where previously global was used now they are singleton class variables - this avoids issues with multi-threading env
1 parent 8cf6993 commit 8407827

File tree

4 files changed

+40
-8
lines changed

4 files changed

+40
-8
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 1.0.4
2+
3+
* feat: use singleton instead of `global` to store shared variables
4+
15
## 1.0.3
26

37
* setting longest_edge=1333 to the table image processor
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.3" # pragma: no cover
1+
__version__ = "1.0.4" # pragma: no cover

unstructured_inference/models/base.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,28 @@
1515

1616
DEFAULT_MODEL = "yolox"
1717

18-
models: Dict[str, UnstructuredModel] = {}
18+
19+
class Models(object):
20+
_instance = None
21+
22+
def __new__(cls):
23+
"""return an instance if one already exists otherwise create an instance"""
24+
if cls._instance is None:
25+
cls._instance = super(Models, cls).__new__(cls)
26+
cls.models: Dict[str, UnstructuredModel] = {}
27+
return cls._instance
28+
29+
def __contains__(self, key):
30+
return key in self.models
31+
32+
def __getitem__(self, key: str):
33+
return self.models.__getitem__(key)
34+
35+
def __setitem__(self, key: str, value: UnstructuredModel):
36+
self.models[key] = value
37+
38+
39+
models: Models = Models()
1940

2041

2142
def get_default_model_mappings() -> Tuple[
@@ -46,8 +67,6 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
4667
# TODO(alan): These cases are similar enough that we can probably do them all together with
4768
# importlib
4869

49-
global models # noqa
50-
5170
if model_name is None:
5271
default_name_from_env = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_NAME")
5372
model_name = default_name_from_env if default_name_from_env is not None else DEFAULT_MODEL

unstructured_inference/models/tables.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,18 @@
2727
class UnstructuredTableTransformerModel(UnstructuredModel):
2828
"""Unstructured model wrapper for table-transformer."""
2929

30+
_instance = None
31+
3032
def __init__(self):
3133
pass
3234

35+
@classmethod
36+
def instance(cls):
37+
"""return an instance if one already exists otherwise create an instance"""
38+
if cls._instance is None:
39+
cls._instance = cls.__new__(cls)
40+
return cls._instance
41+
3342
def predict(
3443
self,
3544
x: PILImage.Image,
@@ -72,7 +81,8 @@ def initialize(
7281
cached_current_verbosity = logging.get_verbosity()
7382
logging.set_verbosity_error()
7483
self.model = TableTransformerForObjectDetection.from_pretrained(
75-
model, device_map=self.device
84+
model,
85+
device_map=self.device,
7686
)
7787
logging.set_verbosity(cached_current_verbosity)
7888
self.model.eval()
@@ -135,12 +145,11 @@ def run_prediction(
135145
return prediction
136146

137147

138-
tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel()
148+
tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel.instance()
139149

140150

141151
def load_agent():
142-
"""Loads the Table agent as a global variable to ensure that we only load it once."""
143-
global tables_agent # noqa
152+
"""Loads the Table agent."""
144153

145154
if not hasattr(tables_agent, "model"):
146155
logger.info("Loading the Table agent ...")

0 commit comments

Comments
 (0)