Skip to content

Commit 8913620

Browse files
authored
enhancement: cache named models (#155)
Fixes the memory issue that was discovered in #152. * Caches named models as they are loaded to prevent reloading or creating multiple copies. * Adds test that fails if model is initialized more than once (fails on main). * This should also address concerns that Chipper was being loaded multiple times.
1 parent d185e64 commit 8913620

File tree

6 files changed

+53
-8
lines changed

6 files changed

+53
-8
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
## 0.5.8-dev0
1+
## 0.5.8-dev1
2+
3+
* Cache named models that have been loaded
24

35
## 0.5.7
46

test_unstructured_inference/models/test_detectron2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@ def detect(self, x):
1717

1818
def test_load_default_model(monkeypatch):
1919
monkeypatch.setattr(detectron2, "Detectron2LayoutModel", MockDetectron2LayoutModel)
20+
monkeypatch.setattr(models, "models", {})
2021

2122
with patch.object(detectron2, "is_detectron2_available", return_value=True):
2223
model = models.get_model("detectron2_lp")
2324

2425
assert isinstance(model.model, MockDetectron2LayoutModel)
2526

2627

27-
def test_load_default_model_raises_when_not_available():
28+
def test_load_default_model_raises_when_not_available(monkeypatch):
29+
monkeypatch.setattr(models, "models", {})
2830
with patch.object(detectron2, "is_detectron2_available", return_value=False):
2931
with pytest.raises(ImportError):
3032
models.get_model("detectron2_lp")

test_unstructured_inference/models/test_detectron2onnx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class input_thing:
2323
return [input_thing()]
2424

2525

26-
def test_load_default_model():
26+
def test_load_default_model(monkeypatch):
27+
monkeypatch.setattr(models, "models", {})
2728
with patch.object(
2829
detectron2.onnxruntime,
2930
"InferenceSession",

test_unstructured_inference/models/test_model.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,33 @@
1+
from typing import Any
2+
from unittest import mock
3+
14
import pytest
25

36
import unstructured_inference.models.base as models
4-
from unstructured_inference.models.unstructuredmodel import ModelNotInitializedError
7+
from unstructured_inference.models.unstructuredmodel import (
8+
ModelNotInitializedError,
9+
UnstructuredObjectDetectionModel,
10+
)
11+
12+
13+
class MockModel(UnstructuredObjectDetectionModel):
14+
call_count = 0
515

16+
initialize = mock.MagicMock()
17+
18+
def __init__(self):
19+
self.initializer = mock.MagicMock()
20+
super().__init__()
621

7-
class MockModel:
822
def initialize(self, *args, **kwargs):
9-
pass
23+
return self.initializer(self, *args, **kwargs)
24+
25+
def predict(self, x: Any) -> Any:
26+
return []
1027

1128

1229
def test_get_model(monkeypatch):
30+
monkeypatch.setattr(models, "models", {})
1331
monkeypatch.setattr(
1432
models,
1533
"UnstructuredDetectronModel",
@@ -36,3 +54,16 @@ def test_raises_invalid_model():
3654
def test_raises_uninitialized():
3755
with pytest.raises(ModelNotInitializedError):
3856
models.UnstructuredDetectronModel().predict(None)
57+
58+
59+
def test_model_initializes_once():
60+
from unstructured_inference.inference import layout
61+
62+
with mock.patch.object(models, "UnstructuredDetectronONNXModel", MockModel), mock.patch.object(
63+
models,
64+
"models",
65+
{},
66+
):
67+
doc = layout.DocumentLayout.from_file("sample-docs/layout-parser-paper.pdf")
68+
69+
doc.pages[0].detection_model.initializer.assert_called_once()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.8-dev0" # pragma: no cover
1+
__version__ = "0.5.8-dev1" # pragma: no cover

unstructured_inference/models/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Dict, Optional
22

33
from unstructured_inference.logger import logger
44
from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
@@ -25,14 +25,22 @@
2525

2626
DEFAULT_MODEL = "detectron2_onnx"
2727

28+
models: Dict[str, UnstructuredModel] = {}
29+
2830

2931
def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
3032
"""Gets the model object by model name."""
3133
# TODO(alan): These cases are similar enough that we can probably do them all together with
3234
# importlib
35+
36+
global models
37+
3338
if model_name is None:
3439
model_name = DEFAULT_MODEL
3540

41+
if model_name in models:
42+
return models[model_name]
43+
3644
if model_name in DETECTRON2_MODEL_TYPES:
3745
model: UnstructuredModel = UnstructuredDetectronModel()
3846
model.initialize(**DETECTRON2_MODEL_TYPES[model_name])
@@ -55,6 +63,7 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
5563
model.initialize(**CHIPPER_MODEL_TYPES[model_name])
5664
else:
5765
raise UnknownModelException(f"Unknown model type: {model_name}")
66+
models[model_name] = model
5867
return model
5968

6069

0 commit comments

Comments
 (0)