Commit bfef09c

fix: Secure torch model inits with global locks (#120)
Signed-off-by: Christoph Auer <[email protected]>
1 parent a7fa2b8 commit bfef09c
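
The pattern applied in all four predictors is the same: a module-level threading.Lock guards the Hugging Face / torch model construction so that predictors created concurrently from multiple threads cannot interleave inside the non-thread-safe initialization code. A minimal, self-contained sketch of the idea follows (the ThreadSafePredictor class and _load_model stub are illustrative stand-ins, not code from this commit):

import threading
import time

# One lock per module, shared by every instance created in this process.
_model_init_lock = threading.Lock()


def _load_model(artifacts_path: str) -> dict:
    """Stand-in for a from_pretrained()-style loader (illustrative only)."""
    time.sleep(0.1)  # simulate slow, non-thread-safe initialization work
    return {"artifacts": artifacts_path}


class ThreadSafePredictor:
    """Hypothetical predictor showing the locking pattern used in this commit."""

    def __init__(self, artifacts_path: str):
        # Only one thread at a time may execute the construction block, so
        # concurrent __init__ calls are serialized rather than interleaved.
        with _model_init_lock:
            self._model = _load_model(artifacts_path)

Only the one-time initialization path takes the lock; inference calls after construction are unaffected.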

File tree

4 files changed: +83 -54 lines changed


docling_ibm_models/code_formula_model/code_formula_predictor.py

Lines changed: 14 additions & 6 deletions
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: MIT
 #
 import logging
+import threading
 from typing import List, Optional, Union

 import numpy as np
@@ -17,6 +18,9 @@

 _log = logging.getLogger(__name__)

+# Global lock for model initialization to prevent threading issues
+_model_init_lock = threading.Lock()
+

 class StopOnString(StoppingCriteria):
     def __init__(self, tokenizer, stop_string):
@@ -80,13 +84,17 @@ def __init__(
         if device == "cpu":
             torch.set_num_threads(self._num_threads)

-        self._tokenizer = AutoTokenizer.from_pretrained(
-            artifacts_path, use_fast=True, padding_side="left"
-        )
-        self._model = SamOPTForCausalLM.from_pretrained(artifacts_path).to(self._device)
-        self._model.eval()
+        # Use lock to prevent threading issues during model initialization
+        with _model_init_lock:
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                artifacts_path, use_fast=True, padding_side="left"
+            )
+            self._model = SamOPTForCausalLM.from_pretrained(artifacts_path).to(
+                self._device
+            )
+            self._model.eval()

-        self._image_processor = SamOptImageProcessor.from_pretrained(artifacts_path)
+            self._image_processor = SamOptImageProcessor.from_pretrained(artifacts_path)

         _log.debug("CodeFormulaModel settings: {}".format(self.info()))

docling_ibm_models/document_figure_classifier_model/document_figure_classifier_predictor.py

Lines changed: 21 additions & 16 deletions
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: MIT
 #
 import logging
+import threading
 from typing import List, Tuple, Union

 import numpy as np
@@ -13,6 +14,9 @@

 _log = logging.getLogger(__name__)

+# Global lock for model initialization to prevent threading issues
+_model_init_lock = threading.Lock()
+

 class DocumentFigureClassifierPredictor:
     r"""
@@ -85,22 +89,23 @@ def __init__(
         if device == "cpu":
             torch.set_num_threads(self._num_threads)

-        model = AutoModelForImageClassification.from_pretrained(artifacts_path)
-        self._model = model.to(device)
-        self._model.eval()
-
-        self._image_processor = transforms.Compose(
-            [
-                transforms.Resize((224, 224)),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=[0.485, 0.456, 0.406],
-                    std=[0.47853944, 0.4732864, 0.47434163],
-                ),
-            ]
-        )
-
-        config = AutoConfig.from_pretrained(artifacts_path)
+        with _model_init_lock:
+            model = AutoModelForImageClassification.from_pretrained(artifacts_path)
+            self._model = model.to(device)
+            self._model.eval()
+
+            self._image_processor = transforms.Compose(
+                [
+                    transforms.Resize((224, 224)),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406],
+                        std=[0.47853944, 0.4732864, 0.47434163],
+                    ),
+                ]
+            )
+
+            config = AutoConfig.from_pretrained(artifacts_path)

         self._classes = list(config.id2label.values())
         self._classes.sort()

docling_ibm_models/layoutmodel/layout_predictor.py

Lines changed: 11 additions & 4 deletions
@@ -4,6 +4,7 @@
 #
 import logging
 import os
+import threading
 from collections.abc import Iterable
 from typing import Set, Union

@@ -15,6 +16,9 @@

 _log = logging.getLogger(__name__)

+# Global lock for model initialization to prevent threading issues
+_model_init_lock = threading.Lock()
+

 class LayoutPredictor:
     """
@@ -87,10 +91,13 @@ def __init__(
         processor_config = os.path.join(artifact_path, "preprocessor_config.json")
         model_config = os.path.join(artifact_path, "config.json")
         self._image_processor = RTDetrImageProcessor.from_json_file(processor_config)
-        self._model = RTDetrForObjectDetection.from_pretrained(
-            artifact_path, config=model_config
-        ).to(self._device)
-        self._model.eval()
+
+        # Use lock to prevent threading issues during model initialization
+        with _model_init_lock:
+            self._model = RTDetrForObjectDetection.from_pretrained(
+                artifact_path, config=model_config
+            ).to(self._device)
+            self._model.eval()

         _log.debug("LayoutPredictor settings: {}".format(self.info()))


docling_ibm_models/tableformer/data_management/tf_predictor.py

Lines changed: 37 additions & 28 deletions
@@ -6,6 +6,7 @@
 import json
 import logging
 import os
+import threading
 from itertools import groupby
 from pathlib import Path

@@ -35,6 +36,9 @@

 logger = s.get_custom_logger(__name__, LOG_LEVEL)

+# Global lock for model initialization to prevent threading issues
+_model_init_lock = threading.Lock()
+

 class bcolors:
     HEADER = "\033[95m"
@@ -175,34 +179,39 @@ def _load_model(self):
         """

         self._model_type = self._config["model"]["type"]
-        model = TableModel04_rs(self._config, self._init_data, self._device)
-
-        if model is None:
-            err_msg = "Not able to initiate a model for {}".format(self._model_type)
-            self._log().error(err_msg)
-            raise ValueError(err_msg)
-
-        self._remove_padding = False
-        if self._model_type == "TableModel02":
-            self._remove_padding = True
-
-        # Load model from safetensors
-        save_dir = self._config["model"]["save_dir"]
-        models_fn = glob.glob(f"{save_dir}/tableformer_*.safetensors")
-        if not models_fn:
-            err_msg = "Not able to find a model file for {}".format(self._model_type)
-            self._log().error(err_msg)
-            raise ValueError(err_msg)
-        model_fn = models_fn[
-            0
-        ]  # Take the first tableformer safetensors file inside the save_dir
-        missing, unexpected = load_model(model, model_fn, device=self._device)
-        if missing or unexpected:
-            err_msg = "Not able to load the model weights for {}".format(
-                self._model_type
-            )
-            self._log().error(err_msg)
-            raise ValueError(err_msg)
+
+        # Use lock to prevent threading issues during model initialization
+        with _model_init_lock:
+            model = TableModel04_rs(self._config, self._init_data, self._device)
+
+            if model is None:
+                err_msg = "Not able to initiate a model for {}".format(self._model_type)
+                self._log().error(err_msg)
+                raise ValueError(err_msg)
+
+            self._remove_padding = False
+            if self._model_type == "TableModel02":
+                self._remove_padding = True
+
+            # Load model from safetensors
+            save_dir = self._config["model"]["save_dir"]
+            models_fn = glob.glob(f"{save_dir}/tableformer_*.safetensors")
+            if not models_fn:
+                err_msg = "Not able to find a model file for {}".format(
+                    self._model_type
+                )
+                self._log().error(err_msg)
+                raise ValueError(err_msg)
+            model_fn = models_fn[
+                0
+            ]  # Take the first tableformer safetensors file inside the save_dir
+            missing, unexpected = load_model(model, model_fn, device=self._device)
+            if missing or unexpected:
+                err_msg = "Not able to load the model weights for {}".format(
+                    self._model_type
+                )
+                self._log().error(err_msg)
+                raise ValueError(err_msg)

         return model

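As a rough illustration of what the lock buys (again a hypothetical sketch, not repo code), several worker threads can now build predictors concurrently without their initialization sections overlapping:

import threading
from concurrent.futures import ThreadPoolExecutor

_model_init_lock = threading.Lock()  # same pattern as in the diffs above
_in_init = 0  # instrumentation for this sketch only


def build_predictor(artifacts_path: str) -> dict:
    """Hypothetical stand-in for one of the predictors' __init__ bodies."""
    global _in_init
    with _model_init_lock:
        _in_init += 1
        assert _in_init == 1, "initialization sections must not overlap"
        model = {"artifacts": artifacts_path}  # pretend from_pretrained(...)
        _in_init -= 1
    return model


with ThreadPoolExecutor(max_workers=4) as pool:
    # Four threads race to initialize; the lock keeps at most one of them
    # inside the critical section at any moment.
    models = list(pool.map(build_predictor, ["artifacts"] * 4))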