Skip to content

Commit c696549

Browse files
authored
fix: address deprecation warnings of dependencies (#2237)
* switch to dtype instead of torch_dtype Signed-off-by: Michele Dolfi <[email protected]> * set __check_model__ to avoid deprecation warnings Signed-off-by: Michele Dolfi <[email protected]> * remove dataloaders warnings in easyocr Signed-off-by: Michele Dolfi <[email protected]> * suppress with option Signed-off-by: Michele Dolfi <[email protected]> --------- Signed-off-by: Michele Dolfi <[email protected]>
1 parent f8cc545 commit c696549

File tree

6 files changed

+27
-12
lines changed

6 files changed

+27
-12
lines changed

docling/datamodel/pipeline_options.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ class EasyOcrOptions(OcrOptions):
135135
recog_network: Optional[str] = "standard"
136136
download_enabled: bool = True
137137

138+
suppress_mps_warnings: bool = True
139+
138140
model_config = ConfigDict(
139141
extra="forbid",
140142
protected_namespaces=(),

docling/models/easyocr_model.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,17 @@ def __init__(
7878
download_enabled = False
7979
model_storage_directory = str(artifacts_path / self._model_repo_folder)
8080

81-
self.reader = easyocr.Reader(
82-
lang_list=self.options.lang,
83-
gpu=use_gpu,
84-
model_storage_directory=model_storage_directory,
85-
recog_network=self.options.recog_network,
86-
download_enabled=download_enabled,
87-
verbose=False,
88-
)
81+
with warnings.catch_warnings():
82+
if self.options.suppress_mps_warnings:
83+
warnings.filterwarnings("ignore", message=".*pin_memory.*MPS.*")
84+
self.reader = easyocr.Reader(
85+
lang_list=self.options.lang,
86+
gpu=use_gpu,
87+
model_storage_directory=model_storage_directory,
88+
recog_network=self.options.recog_network,
89+
download_enabled=download_enabled,
90+
verbose=False,
91+
)
8992

9093
@staticmethod
9194
def download_models(
@@ -147,7 +150,14 @@ def __call__(
147150
scale=self.scale, cropbox=ocr_rect
148151
)
149152
im = numpy.array(high_res_image)
150-
result = self.reader.readtext(im)
153+
154+
with warnings.catch_warnings():
155+
if self.options.suppress_mps_warnings:
156+
warnings.filterwarnings(
157+
"ignore", message=".*pin_memory.*MPS.*"
158+
)
159+
160+
result = self.reader.readtext(im)
151161

152162
del high_res_image
153163
del im

docling/models/picture_description_vlm_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
self.model = AutoModelForImageTextToText.from_pretrained(
6868
artifacts_path,
6969
device_map=self.device,
70-
torch_dtype=torch.bfloat16,
70+
dtype=torch.bfloat16,
7171
_attn_implementation=(
7272
"flash_attention_2"
7373
if self.device.startswith("cuda")

docling/models/vlm_models_inline/hf_transformers_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
self.vlm_model = model_cls.from_pretrained(
113113
artifacts_path,
114114
device_map=self.device,
115-
torch_dtype=self.vlm_options.torch_dtype,
115+
dtype=self.vlm_options.torch_dtype,
116116
_attn_implementation=(
117117
"flash_attention_2"
118118
if self.device.startswith("cuda")

docling/models/vlm_models_inline/nuextract_transformers_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(
144144
self.vlm_model = AutoModelForImageTextToText.from_pretrained(
145145
artifacts_path,
146146
device_map=self.device,
147-
torch_dtype=self.vlm_options.torch_dtype,
147+
dtype=self.vlm_options.torch_dtype,
148148
_attn_implementation=(
149149
"flash_attention_2"
150150
if self.device.startswith("cuda")

docling/pipeline/extraction_vlm_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ def _serialize_template(self, template: ExtractionTemplateType) -> str:
194194
class ExtractionTemplateFactory(ModelFactory[template]): # type: ignore
195195
__use_examples__ = True # prefer Field(examples=...) when present
196196
__use_defaults__ = True # use field defaults instead of random values
197+
__check_model__ = (
198+
True # setting the value to avoid deprecation warnings
199+
)
197200

198201
return ExtractionTemplateFactory.build().model_dump_json(indent=2) # type: ignore
199202
else:

0 commit comments

Comments
 (0)