Skip to content

Commit 02f6fcc

Browse files
authored
Fix/od initialization threadsafe (#446)
Make OD model initialization thread-safe.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Make `get_model` thread-safe with a lock, add a multithreaded test, and bump the version to 1.0.9 with a changelog update.
>
> - **Models**:
>   - Thread-safe initialization in `unstructured_inference/models/base.py`: add `models_lock` and wrap the `get_model` initialization path in the lock, re-checking the cache once the lock is held (double-check).
> - **Tests**:
>   - Add `test_get_model_threaded` in `test_unstructured_inference/models/test_model.py` to validate concurrent `get_model` calls.
> - **Versioning**:
>   - Bump `__version__` to `1.0.9`.
>   - Update `CHANGELOG.md` with a 1.0.9 entry about thread-safe OD model loading.
>
> <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 7609025. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup>

<!-- /CURSOR_SUMMARY -->
1 parent 5c352fb commit 02f6fcc

File tree

4 files changed

+72
-16
lines changed

4 files changed

+72
-16
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 1.0.9
2+
3+
* Make OD model loading thread safe
4+
15
## 1.0.8-dev2
26

37
* Enhancement: Optimized `zoom_image` (codeflash)

test_unstructured_inference/models/test_model.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import threading
23
from typing import Any
34
from unittest import mock
45

@@ -40,6 +41,49 @@ def test_get_model(monkeypatch):
4041
assert isinstance(models.get_model("yolox"), MockModel)
4142

4243

44+
def test_get_model_threaded(monkeypatch):
    """Test that get_model works correctly when called from multiple threads simultaneously.

    All workers block on a shared barrier so their ``get_model`` calls are released
    together, which makes the calls actually overlap. Besides checking that no worker
    raised and that every worker received a ``MockModel``, the test verifies that all
    workers received the *same* instance, i.e. the model was created and cached once.
    """
    monkeypatch.setattr(models, "models", {})

    num_threads = 10
    # Released only when every worker has arrived, so the calls contend for the cache.
    start_barrier = threading.Barrier(num_threads)

    # Results and exceptions from threads will be stored here
    results = []
    exceptions = []

    def get_model_worker(thread_id):
        """Worker function for each thread."""
        try:
            # The timeout turns a stuck worker into a recorded failure instead of a hang.
            start_barrier.wait(timeout=10)
            model = models.get_model("yolox")
            results.append((thread_id, model))
        except Exception as e:
            exceptions.append((thread_id, e))

    threads = [
        threading.Thread(target=get_model_worker, args=(i,)) for i in range(num_threads)
    ]

    # The class-map patch must stay active until every worker has finished.
    with mock.patch.dict(models.model_class_map, {"yolox": MockModel}):
        for thread in threads:
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

    # Verify no exceptions occurred
    assert len(exceptions) == 0, f"Exceptions occurred in threads: {exceptions}"

    # Verify all threads got results
    assert len(results) == num_threads, f"Expected {num_threads} results, got {len(results)}"

    # Verify all results are MockModel instances and that they are one shared object
    _, first_model = results[0]
    for thread_id, model in results:
        assert isinstance(
            model, MockModel
        ), f"Thread {thread_id} got unexpected model type: {type(model)}"
        assert (
            model is first_model
        ), f"Thread {thread_id} got a different model instance; model was initialized more than once"
85+
86+
4387
def test_register_new_model():
4488
assert "foo" not in models.model_class_map
4589
assert "foo" not in models.model_config_map
unstructured_inference/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.8-dev2" # pragma: no cover
1+
__version__ = "1.0.9" # pragma: no cover

unstructured_inference/models/base.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def __setitem__(self, key: str, value: UnstructuredModel):
4242

4343
models: Models = Models()
4444

45+
models_lock = threading.Lock()
46+
4547

4648
def get_default_model_mappings() -> Tuple[
4749
Dict[str, Type[UnstructuredModel]],
@@ -78,24 +80,30 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
7880
if model_name in models:
7981
return models[model_name]
8082

81-
initialize_param_json = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH")
82-
if initialize_param_json is not None:
83-
with open(initialize_param_json) as fp:
84-
initialize_params = json.load(fp)
85-
label_map_int_keys = {
86-
int(key): value for key, value in initialize_params["label_map"].items()
87-
}
88-
initialize_params["label_map"] = label_map_int_keys
89-
else:
90-
if model_name in model_config_map:
91-
initialize_params = model_config_map[model_name]
83+
with models_lock:
84+
if model_name in models:
85+
return models[model_name]
86+
87+
initialize_param_json = os.environ.get(
88+
"UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH"
89+
)
90+
if initialize_param_json is not None:
91+
with open(initialize_param_json) as fp:
92+
initialize_params = json.load(fp)
93+
label_map_int_keys = {
94+
int(key): value for key, value in initialize_params["label_map"].items()
95+
}
96+
initialize_params["label_map"] = label_map_int_keys
9297
else:
93-
raise UnknownModelException(f"Unknown model type: {model_name}")
98+
if model_name in model_config_map:
99+
initialize_params = model_config_map[model_name]
100+
else:
101+
raise UnknownModelException(f"Unknown model type: {model_name}")
94102

95-
model: UnstructuredModel = model_class_map[model_name]()
103+
model: UnstructuredModel = model_class_map[model_name]()
96104

97-
model.initialize(**initialize_params)
98-
models[model_name] = model
105+
model.initialize(**initialize_params)
106+
models[model_name] = model
99107
return model
100108

101109

0 commit comments

Comments
 (0)