Skip to content

Commit fc2ab41

Browse files
committed
Add model downloader support
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
1 parent fa6e5fb commit fc2ab41

File tree

4 files changed

+65
-16
lines changed

4 files changed

+65
-16
lines changed

docling/cli/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class _AvailableModels(str, Enum):
4242
GRANITE_CHART_EXTRACTION = "granite_chart_extraction"
4343
RAPIDOCR = "rapidocr"
4444
EASYOCR = "easyocr"
45+
NEMOTRON_OCR = "nemotron_ocr"
4546

4647

4748
_default_models = [
@@ -123,6 +124,7 @@ def download(
123124
in to_download,
124125
with_rapidocr=_AvailableModels.RAPIDOCR in to_download,
125126
with_easyocr=_AvailableModels.EASYOCR in to_download,
127+
with_nemotron_ocr=_AvailableModels.NEMOTRON_OCR in to_download,
126128
)
127129

128130
if quiet:

docling/datamodel/pipeline_options.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,8 @@ class NemotronOcrOptions(OcrOptions):
329329
Notes:
330330
Nemotron OCR does not expose runtime language selection through its public
331331
API. The `lang` field is kept only for compatibility with the shared OCR
332-
options interface.
332+
options interface. Use the pipeline-level `artifacts_path` to point to
333+
pre-downloaded checkpoint artifacts.
333334
"""
334335

335336
kind: ClassVar[Literal["nemotron-ocr"]] = "nemotron-ocr"
@@ -342,16 +343,6 @@ class NemotronOcrOptions(OcrOptions):
342343
)
343344
),
344345
] = []
345-
model_dir: Annotated[
346-
Optional[Path],
347-
Field(
348-
description=(
349-
"Optional directory containing the Nemotron OCR checkpoint files "
350-
"(`detector.pth`, `recognizer.pth`, `relational.pth`, `charset.txt`). "
351-
"If omitted, the upstream package downloads them from Hugging Face."
352-
)
353-
),
354-
] = None
355346
merge_level: Annotated[
356347
Literal["word", "sentence", "paragraph"],
357348
Field(

docling/models/stages/ocr/nemotron_ocr_model.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from docling.datamodel.settings import settings
2020
from docling.models.base_ocr_model import BaseOcrModel
21+
from docling.models.utils.hf_model_download import download_hf_model
2122
from docling.utils.accelerator_utils import decide_device
2223
from docling.utils.profiling import TimeRecorder
2324

@@ -53,6 +54,8 @@ class NemotronOcrPrediction(TypedDict):
5354

5455

5556
class NemotronOcrModel(BaseOcrModel):
57+
_repo_id = "nvidia/nemotron-ocr-v1"
58+
5659
def __init__(
5760
self,
5861
enabled: bool,
@@ -81,12 +84,10 @@ def __init__(
8184
"Python 3.12 and CUDA 13.x."
8285
) from exc
8386

84-
model_dir = (
85-
str(self.options.model_dir)
86-
if self.options.model_dir is not None
87-
else None
87+
model_dir = self._resolve_model_dir(artifacts_path=artifacts_path)
88+
self.reader = NemotronOCR(
89+
model_dir=None if model_dir is None else str(model_dir)
8890
)
89-
self.reader = NemotronOCR(model_dir=model_dir)
9091
# Install the storage workaround only at the upstream grid-sampler
9192
# boundary, keeping the rest of the Nemotron integration unchanged.
9293
self.reader.grid_sampler = _GridSamplerStorageWorkaround(
@@ -132,6 +133,51 @@ def _validate_runtime(cls, accelerator_options: AcceleratorOptions) -> None:
132133
f"reports CUDA {cuda_version!r}."
133134
)
134135

136+
@classmethod
def _resolve_model_dir(cls, artifacts_path: Optional[Path]) -> Optional[Path]:
    """Locate the pre-downloaded Nemotron OCR checkpoints under ``artifacts_path``.

    Args:
        artifacts_path: Root directory holding pre-downloaded model artifacts,
            or ``None`` when no local artifacts directory is configured.

    Returns:
        ``None`` when ``artifacts_path`` is ``None`` (the caller then passes
        ``model_dir=None`` to the upstream reader), otherwise the
        ``<artifacts_path>/<repo-folder>/checkpoints`` directory.

    Raises:
        FileNotFoundError: If the checkpoints directory is missing. The message
            lists the directories that are present to ease troubleshooting.
    """
    if artifacts_path is None:
        return None

    # The cache folder is the repo id with "/" replaced by "--".
    repo_cache_folder = cls._repo_id.replace("/", "--")
    checkpoints_dir = artifacts_path / repo_cache_folder / "checkpoints"

    # Check the directory that is actually returned, not just its parent:
    # a partial or interrupted download can leave the repo folder behind
    # without the checkpoints inside it.
    if checkpoints_dir.is_dir():
        return checkpoints_dir

    # `is_dir()` rather than `exists()` so that an artifacts_path pointing at
    # a regular file does not make `iterdir()` raise NotADirectoryError.
    available_dirs = []
    if artifacts_path.is_dir():
        available_dirs = sorted(
            path.name for path in artifacts_path.iterdir() if path.is_dir()
        )

    raise FileNotFoundError(
        "Nemotron OCR artifacts not found in artifacts_path.\n"
        f"Expected location: {checkpoints_dir}\n"
        f"Available directories in {artifacts_path}: {available_dirs}\n"
        "Use `docling-tools models download nemotron_ocr` to pre-download "
        "the checkpoints or unset artifacts_path to allow the upstream "
        "package to download them."
    )
159+
160+
@staticmethod
def download_models(
    local_dir: Optional[Path] = None,
    force: bool = False,
    progress: bool = False,
) -> Path:
    """Download the Nemotron OCR repository into a local directory.

    Args:
        local_dir: Destination directory. When ``None``, defaults to
            ``settings.cache_dir / "models" / <repo id with "/" -> "--">``.
        force: Forwarded unchanged to ``download_hf_model``.
        progress: Forwarded unchanged to ``download_hf_model``.

    Returns:
        The path returned by ``download_hf_model``.
    """
    repo_id = NemotronOcrModel._repo_id

    target_dir = local_dir
    if target_dir is None:
        # Fall back to the shared model cache, keyed by the repo id.
        target_dir = settings.cache_dir / "models" / repo_id.replace("/", "--")

    # Make sure the destination exists before handing it to the downloader.
    target_dir.mkdir(parents=True, exist_ok=True)

    return download_hf_model(
        repo_id=repo_id,
        local_dir=target_dir,
        force=force,
        progress=progress,
    )
180+
135181
@staticmethod
136182
def _prediction_to_cell(
137183
prediction: NemotronOcrPrediction,

docling/utils/model_downloader.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from docling.models.stages.code_formula.code_formula_model import CodeFormulaModel
2121
from docling.models.stages.layout.layout_model import LayoutModel
2222
from docling.models.stages.ocr.easyocr_model import EasyOcrModel
23+
from docling.models.stages.ocr.nemotron_ocr_model import NemotronOcrModel
2324
from docling.models.stages.ocr.rapid_ocr_model import RapidOcrModel
2425
from docling.models.stages.picture_classifier.document_picture_classifier import (
2526
DocumentPictureClassifier,
@@ -55,6 +56,7 @@ def download_models(
5556
with_granite_chart_extraction: bool = False,
5657
with_rapidocr: bool = True,
5758
with_easyocr: bool = False,
59+
with_nemotron_ocr: bool = False,
5860
):
5961
if output_dir is None:
6062
output_dir = settings.cache_dir / "models"
@@ -189,4 +191,12 @@ def download_models(
189191
progress=progress,
190192
)
191193

194+
if with_nemotron_ocr:
195+
_log.info("Downloading nemotron OCR model...")
196+
NemotronOcrModel.download_models(
197+
local_dir=output_dir / NemotronOcrModel._repo_id.replace("/", "--"),
198+
force=force,
199+
progress=progress,
200+
)
201+
192202
return output_dir

0 commit comments

Comments
 (0)