diff --git a/CHANGELOG.md b/CHANGELOG.md index b11d8ec8..f4f991bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 1.0.4 + +* feat: use singleton instead of `global` to store shared variables + ## 1.0.3 * setting longest_edge=1333 to the table image processor diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 6520c47e..18934c58 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "1.0.3" # pragma: no cover +__version__ = "1.0.4" # pragma: no cover diff --git a/unstructured_inference/models/base.py b/unstructured_inference/models/base.py index 826bf5a5..37344100 100644 --- a/unstructured_inference/models/base.py +++ b/unstructured_inference/models/base.py @@ -15,7 +15,28 @@ DEFAULT_MODEL = "yolox" -models: Dict[str, UnstructuredModel] = {} + +class Models(object): + _instance = None + + def __new__(cls): + """return an instance if one already exists otherwise create an instance""" + if cls._instance is None: + cls._instance = super(Models, cls).__new__(cls) + cls.models: Dict[str, UnstructuredModel] = {} + return cls._instance + + def __contains__(self, key): + return key in self.models + + def __getitem__(self, key: str): + return self.models.__getitem__(key) + + def __setitem__(self, key: str, value: UnstructuredModel): + self.models[key] = value + + +models: Models = Models() def get_default_model_mappings() -> Tuple[ @@ -46,8 +67,6 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel: # TODO(alan): These cases are similar enough that we can probably do them all together with # importlib - global models # noqa - if model_name is None: default_name_from_env = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_NAME") model_name = default_name_from_env if default_name_from_env is not None else DEFAULT_MODEL diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index 45c98017..9760dfc4 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -27,9 +27,18 @@ class UnstructuredTableTransformerModel(UnstructuredModel): """Unstructured model wrapper for table-transformer.""" + _instance = None + def __init__(self): pass + @classmethod + def instance(cls): + """return an instance if one already exists otherwise create an instance""" + if cls._instance is None: + cls._instance = cls.__new__(cls) + return cls._instance + def predict( self, x: PILImage.Image, @@ -72,7 +81,8 @@ def initialize( cached_current_verbosity = logging.get_verbosity() logging.set_verbosity_error() self.model = TableTransformerForObjectDetection.from_pretrained( - model, device_map=self.device + model, + device_map=self.device, ) logging.set_verbosity(cached_current_verbosity) self.model.eval() @@ -135,12 +145,11 @@ def run_prediction( return prediction -tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel() +tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel.instance() def load_agent(): - """Loads the Table agent as a global variable to ensure that we only load it once.""" - global tables_agent # noqa + """Loads the Table agent.""" if not hasattr(tables_agent, "model"): logger.info("Loading the Table agent ...")