Skip to content

Commit 147c5b1

Browse files
authored
Refactor/get model allow registering of new models with function calls (#333)
This PR allows us to register new models and use them in the system:

- `get_model` is refactored so that it relies on the information stored in `model_class_map` and `model_config_map` to initialize a model from a given model name (the model name is the key into both mappings)
- a new function, `unstructured_inference.models.base.register_new_model`, adds a new model definition to the class mapping and the config mapping
- after calling `register_new_model`, one can call `get_model` with the new model name to get the new model

## Testing

The new unit tests should pass.
1 parent 0a08377 commit 147c5b1

File tree

6 files changed

+68
-20
lines changed

6 files changed

+68
-20
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
## 0.7.26-dev0
1+
## 0.7.26
2+
23
* feat: add a set of new `ElementType`s to extend future element types recognition
4+
* feat: allow registering of new models for inference using `unstructured_inference.models.base.register_new_model` function
35

46
## 0.7.25
57

test_unstructured_inference/models/test_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,35 @@ def predict(self, x: Any) -> Any:
2626
return []
2727

2828

29+
MOCK_MODEL_TYPES = {
30+
"foo": {
31+
"input_shape": (640, 640),
32+
},
33+
}
34+
35+
2936
def test_get_model(monkeypatch):
3037
monkeypatch.setattr(models, "models", {})
3138
with mock.patch.dict(models.model_class_map, {"checkbox": MockModel}):
3239
assert isinstance(models.get_model("checkbox"), MockModel)
3340

3441

42+
def test_register_new_model():
43+
assert "foo" not in models.model_class_map
44+
assert "foo" not in models.model_config_map
45+
models.register_new_model(MOCK_MODEL_TYPES, MockModel)
46+
assert "foo" in models.model_class_map
47+
assert "foo" in models.model_config_map
48+
model = models.get_model("foo")
49+
assert len(model.initializer.mock_calls) == 1
50+
assert model.initializer.mock_calls[0][-1] == MOCK_MODEL_TYPES["foo"]
51+
assert isinstance(model, MockModel)
52+
# unregister the new model by reset to default
53+
models.model_class_map, models.model_config_map = models.get_default_model_mappings()
54+
assert "foo" not in models.model_class_map
55+
assert "foo" not in models.model_config_map
56+
57+
3558
def test_raises_invalid_model():
3659
with pytest.raises(models.UnknownModelException):
3760
models.get_model("fake_model")

unstructured_inference/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.7.26-dev0" # pragma: no cover
1+
__version__ = "0.7.26" # pragma: no cover

unstructured_inference/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
settings that should not be altered without making a code change (e.g., definition of 1Gb of memory
66
in bytes). Constants should go into `./constants.py`
77
"""
8+
89
import os
910
from dataclasses import dataclass
1011

unstructured_inference/inference/layoutelement.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,11 @@ def merge_inferred_layout_with_extracted_layout(
165165
categorized_extracted_elements_to_add = [
166166
LayoutElement(
167167
text=el.text,
168-
type=ElementType.IMAGE
169-
if isinstance(el, ImageTextRegion)
170-
else ElementType.UNCATEGORIZED_TEXT,
168+
type=(
169+
ElementType.IMAGE
170+
if isinstance(el, ImageTextRegion)
171+
else ElementType.UNCATEGORIZED_TEXT
172+
),
171173
source=el.source,
172174
bbox=el.bbox,
173175
)

unstructured_inference/models/base.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
2+
13
import json
24
import os
3-
from typing import Dict, Optional, Type
5+
from typing import Dict, Optional, Tuple, Type
46

57
from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
68
from unstructured_inference.models.chipper import UnstructuredChipperModel
@@ -15,17 +17,41 @@
1517
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
1618
from unstructured_inference.models.yolox import MODEL_TYPES as YOLOX_MODEL_TYPES
1719
from unstructured_inference.models.yolox import UnstructuredYoloXModel
20+
from unstructured_inference.utils import LazyDict
1821

1922
DEFAULT_MODEL = "yolox"
2023

2124
models: Dict[str, UnstructuredModel] = {}
2225

23-
model_class_map: Dict[str, Type[UnstructuredModel]] = {
24-
**{name: UnstructuredDetectronModel for name in DETECTRON2_MODEL_TYPES},
25-
**{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES},
26-
**{name: UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES},
27-
**{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES},
28-
}
26+
27+
def get_default_model_mappings() -> (
28+
Tuple[
29+
Dict[str, Type[UnstructuredModel]],
30+
Dict[str, dict | LazyDict],
31+
]
32+
):
33+
"""default model mappings for models that are in `unstructured_inference` repo"""
34+
return {
35+
**{name: UnstructuredDetectronModel for name in DETECTRON2_MODEL_TYPES},
36+
**{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES},
37+
**{name: UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES},
38+
**{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES},
39+
}, {
40+
**DETECTRON2_MODEL_TYPES,
41+
**DETECTRON2_ONNX_MODEL_TYPES,
42+
**YOLOX_MODEL_TYPES,
43+
**CHIPPER_MODEL_TYPES,
44+
}
45+
46+
47+
model_class_map, model_config_map = get_default_model_mappings()
48+
49+
50+
def register_new_model(model_config: dict, model_class: UnstructuredModel):
51+
"""registering a new model by updating the model_config_map and model_class_map with the new
52+
model class information"""
53+
model_config_map.update(model_config)
54+
model_class_map.update({name: model_class for name in model_config})
2955

3056

3157
def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
@@ -51,14 +77,8 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
5177
}
5278
initialize_params["label_map"] = label_map_int_keys
5379
else:
54-
if model_name in DETECTRON2_MODEL_TYPES:
55-
initialize_params = DETECTRON2_MODEL_TYPES[model_name]
56-
elif model_name in DETECTRON2_ONNX_MODEL_TYPES:
57-
initialize_params = DETECTRON2_ONNX_MODEL_TYPES[model_name]
58-
elif model_name in YOLOX_MODEL_TYPES:
59-
initialize_params = YOLOX_MODEL_TYPES[model_name]
60-
elif model_name in CHIPPER_MODEL_TYPES:
61-
initialize_params = CHIPPER_MODEL_TYPES[model_name]
80+
if model_name in model_config_map:
81+
initialize_params = model_config_map[model_name]
6282
else:
6383
raise UnknownModelException(f"Unknown model type: {model_name}")
6484

0 commit comments

Comments
 (0)