Skip to content

Commit 04295b2

Browse files
feat!: New API for models initialization with accelerators parameters. Use HF implementation for LayoutPredictor. Migrate models to safetensors format. (#50)
Signed-off-by: Nikos Livathinos <[email protected]> Co-authored-by: Christoph Auer <[email protected]>
1 parent 33e0216 commit 04295b2

File tree

9 files changed

+1426
-276
lines changed

9 files changed

+1426
-276
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,9 @@ repos:
1919
entry: poetry lock --check
2020
pass_filenames: false
2121
language: system
22-
23-
# Ready to be enabled soon
24-
# - repo: local
25-
# hooks:
26-
# - id: system
27-
# name: flake8
28-
# entry: poetry run flake8 docling_ibm_models
29-
# pass_filenames: false
30-
# language: system
31-
# files: '\.py$'
32-
# - repo: local
33-
# hooks:
34-
# - id: system
35-
# name: MyPy
36-
# entry: poetry run mypy docling_ibm_models
37-
# pass_filenames: false
38-
# language: system
39-
# files: '\.py$'
22+
# - id: system
23+
# name: MyPy
24+
# entry: poetry run mypy docling_ibm_models
25+
# pass_filenames: false
26+
# language: system
27+
# files: '\.py$'

demo/demo_layout_predictor.py

Lines changed: 67 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,52 @@
1010
from pathlib import Path
1111

1212
import numpy as np
13-
from PIL import Image, ImageDraw
13+
import torch
1414
from huggingface_hub import snapshot_download
15+
from PIL import Image, ImageDraw, ImageFont
1516

1617
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
1718

1819

20+
def save_predictions(prefix: str, viz_dir: str, img_fn: str, img, predictions: dict):
21+
img_path = Path(img_fn)
22+
23+
image = img.copy()
24+
draw = ImageDraw.Draw(image)
25+
26+
predictions_filename = f"{prefix}_{img_path.stem}.txt"
27+
predictions_fn = os.path.join(viz_dir, predictions_filename)
28+
with open(predictions_fn, "w") as fd:
29+
for pred in predictions:
30+
bbox = [
31+
round(pred["l"], 2),
32+
round(pred["t"], 2),
33+
round(pred["r"], 2),
34+
round(pred["b"], 2),
35+
]
36+
label = pred["label"]
37+
confidence = round(pred["confidence"], 3)
38+
39+
# Save the predictions in txt file
40+
pred_txt = f"{prefix} {img_fn}: {label} - {bbox} - {confidence}\n"
41+
fd.write(pred_txt)
42+
43+
# Draw the bbox and label
44+
draw.rectangle(bbox, outline="orange")
45+
txt = f"{label}: {confidence}"
46+
draw.text(
47+
(bbox[0], bbox[1]), text=txt, font=ImageFont.load_default(), fill="blue"
48+
)
49+
50+
draw_filename = f"{prefix}_{img_path.name}"
51+
draw_fn = os.path.join(viz_dir, draw_filename)
52+
image.save(draw_fn)
53+
54+
1955
def demo(
2056
logger: logging.Logger,
2157
artifact_path: str,
58+
device: str,
2259
num_threads: int,
2360
img_dir: str,
2461
viz_dir: str,
@@ -30,58 +67,43 @@ def demo(
3067
pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0)
3168
"""
3269
# Create the layout predictor
33-
lpredictor = LayoutPredictor(artifact_path, num_threads=num_threads)
34-
logger.info("LayoutPredictor settings: {}".format(lpredictor.info()))
70+
lpredictor = LayoutPredictor(artifact_path, device=device, num_threads=num_threads)
3571

3672
# Predict all test png images
73+
t0 = time.perf_counter()
74+
img_counter = 0
3775
for img_fn in Path(img_dir).rglob("*.png"):
76+
img_counter += 1
3877
logger.info("Predicting '%s'...", img_fn)
39-
start_t = time.time()
4078

4179
with Image.open(img_fn) as image:
4280
# Predict layout
81+
img_t0 = time.perf_counter()
4382
preds = list(lpredictor.predict(image))
44-
dt_ms = 1000 * (time.time() - start_t)
45-
logger.debug("Time elapsed for prediction(ms): %s", dt_ms)
46-
47-
# Draw predictions
48-
out_img = image.copy()
49-
draw = ImageDraw.Draw(out_img)
50-
51-
for i, pred in enumerate(preds):
52-
score = pred["confidence"]
53-
label = pred["label"]
54-
box = [
55-
round(pred["l"]),
56-
round(pred["t"]),
57-
round(pred["r"]),
58-
round(pred["b"]),
59-
]
60-
61-
# Draw bbox and label
62-
draw.rectangle(
63-
box,
64-
outline="red",
65-
)
66-
draw.text(
67-
(box[0], box[1]),
68-
text=str(label),
69-
fill="blue",
70-
)
71-
logger.info("%s: [label|score|bbox] = ['%s' | %s | %s]", i, label, score, box)
72-
73-
save_fn = os.path.join(viz_dir, os.path.basename(img_fn))
74-
out_img.save(save_fn)
75-
logger.info("Saving prediction visualization in: '%s'", save_fn)
83+
img_ms = 1000 * (time.perf_counter() - img_t0)
84+
logger.debug("Prediction(ms): {:.2f}".format(img_ms))
85+
86+
# Save predictions
87+
logger.info("Saving prediction visualization in: '%s'", viz_dir)
88+
save_predictions("ST", viz_dir, img_fn, image, preds)
89+
total_ms = 1000 * (time.perf_counter() - t0)
90+
avg_ms = (total_ms / img_counter) if img_counter > 0 else 0
91+
logger.info(
92+
"For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format(
93+
img_counter, total_ms, avg_ms
94+
)
95+
)
7696

7797

7898
def main(args):
7999
r""" """
80100
num_threads = int(args.num_threads) if args.num_threads is not None else None
101+
device = args.device.lower()
81102
img_dir = args.img_dir
82103
viz_dir = args.viz_dir
83104

84105
# Initialize logger
106+
logging.basicConfig(level=logging.DEBUG)
85107
logger = logging.getLogger("LayoutPredictor")
86108
logger.setLevel(logging.DEBUG)
87109
if not logger.hasHandlers():
@@ -96,11 +118,13 @@ def main(args):
96118
Path(viz_dir).mkdir(parents=True, exist_ok=True)
97119

98120
# Download models from HF
99-
download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.0.1")
100-
artifact_path = os.path.join(download_path, "model_artifacts/layout/beehive_v0.0.5_pt")
121+
download_path = snapshot_download(
122+
repo_id="ds4sd/docling-models", revision="v2.1.0"
123+
)
124+
artifact_path = os.path.join(download_path, "model_artifacts/layout")
101125

102126
# Test the LayoutPredictor
103-
demo(logger, artifact_path, num_threads, img_dir, viz_dir)
127+
demo(logger, artifact_path, device, num_threads, img_dir, viz_dir)
104128

105129

106130
if __name__ == "__main__":
@@ -109,7 +133,10 @@ def main(args):
109133
"""
110134
parser = argparse.ArgumentParser(description="Test the LayoutPredictor")
111135
parser.add_argument(
112-
"-n", "--num_threads", required=False, default=None, help="Number of threads"
136+
"-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]"
137+
)
138+
parser.add_argument(
139+
"-n", "--num_threads", required=False, default=4, help="Number of threads"
113140
)
114141
parser.add_argument(
115142
"-i",

docling_ibm_models/layoutmodel/layout_predictor.py

Lines changed: 74 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright IBM Corp. 2024 - 2024
33
# SPDX-License-Identifier: MIT
44
#
5+
import logging
56
import os
67
from collections.abc import Iterable
78
from typing import Union
@@ -10,38 +11,30 @@
1011
import torch
1112
import torchvision.transforms as T
1213
from PIL import Image
14+
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
1315

14-
MODEL_CHECKPOINT_FN = "model.pt"
15-
DEFAULT_NUM_THREADS = 4
16+
_log = logging.getLogger(__name__)
1617

1718

1819
class LayoutPredictor:
19-
r"""
20-
Document layout prediction using torch
20+
"""
21+
Document layout prediction using safe tensors
2122
"""
2223

2324
def __init__(
24-
self, artifact_path: str, num_threads: int = None, use_cpu_only: bool = False
25+
self,
26+
artifact_path: str,
27+
device: str = "cpu",
28+
num_threads: int = 4,
2529
):
26-
r"""
30+
"""
2731
Provide the artifact path that contains the LayoutModel file
2832
29-
The number of threads is decided, in the following order, by:
30-
1. The init method parameter `num_threads`, if it is set.
31-
2. The envvar "OMP_NUM_THREADS", if it is set.
32-
3. The default value DEFAULT_NUM_THREADS.
33-
34-
The execution provided is decided, in the following order:
35-
1. If the init method parameter `cpu_only` is True or the envvar "USE_CPU_ONLY" is set,
36-
it uses the "CPUExecutionProvider".
37-
3. Otherwise if the "CUDAExecutionProvider" is present, use:
38-
["CUDAExecutionProvider", "CPUExecutionProvider"]:
39-
4033
Parameters
4134
----------
4235
artifact_path: Path for the model torch file.
43-
num_threads: (Optional) Number of threads to run the inference.
44-
use_cpu_only: (Optional) If True, it forces CPU as the execution provider.
36+
device: (Optional) device to run the inference.
37+
num_threads: (Optional) Number of threads to run the inference if device = 'cpu'
4538
4639
Raises
4740
------
@@ -70,40 +63,51 @@ def __init__(
7063
}
7164

7265
# Blacklisted classes
73-
self._black_classes = set(["Form", "Key-Value Region"])
66+
self._black_classes = set() # ["Form", "Key-Value Region"])
7467

7568
# Set basic params
76-
self._threshold = 0.6 # Score threshold
69+
self._threshold = 0.3 # Score threshold
7770
self._image_size = 640
7871
self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
79-
self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
8072

81-
# Model file
82-
self._torch_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
83-
if not os.path.isfile(self._torch_fn):
84-
raise FileNotFoundError("Missing torch file: {}".format(self._torch_fn))
85-
86-
# Get env vars
87-
if num_threads is None:
88-
num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
73+
# Set number of threads for CPU
74+
self._device = torch.device(device)
8975
self._num_threads = num_threads
76+
if device == "cpu":
77+
torch.set_num_threads(self._num_threads)
78+
79+
# Model file and configurations
80+
self._st_fn = os.path.join(artifact_path, "model.safetensors")
81+
if not os.path.isfile(self._st_fn):
82+
raise FileNotFoundError("Missing safe tensors file: {}".format(self._st_fn))
9083

91-
self.model = torch.jit.load(self._torch_fn)
84+
# Load model and move to device
85+
processor_config = os.path.join(artifact_path, "preprocessor_config.json")
86+
model_config = os.path.join(artifact_path, "config.json")
87+
self._image_processor = RTDetrImageProcessor.from_json_file(processor_config)
88+
self._model = RTDetrForObjectDetection.from_pretrained(
89+
artifact_path, config=model_config
90+
).to(self._device)
91+
self._model.eval()
92+
93+
_log.debug("LayoutPredictor settings: {}".format(self.info()))
9294

9395
def info(self) -> dict:
94-
r"""
96+
"""
9597
Get information about the configuration of LayoutPredictor
9698
"""
9799
info = {
98-
"torch_file": self._torch_fn,
99-
"use_cpu_only": self._use_cpu_only,
100+
"safe_tensors_file": self._st_fn,
101+
"device": self._device.type,
102+
"num_threads": self._num_threads,
100103
"image_size": self._image_size,
101104
"threshold": self._threshold,
102105
}
103106
return info
104107

108+
@torch.inference_mode()
105109
def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
106-
r"""
110+
"""
107111
Predict bounding boxes for a given image.
108112
The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
109113
[left, top, right, bottom]
@@ -128,40 +132,44 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
128132
else:
129133
raise TypeError("Not supported input image format")
130134

135+
resize = {"height": self._image_size, "width": self._image_size}
136+
inputs = self._image_processor(
137+
images=page_img,
138+
return_tensors="pt",
139+
size=resize,
140+
).to(self._device)
141+
outputs = self._model(**inputs)
142+
results = self._image_processor.post_process_object_detection(
143+
outputs,
144+
target_sizes=torch.tensor([page_img.size[::-1]]),
145+
threshold=self._threshold,
146+
)
147+
131148
w, h = page_img.size
132-
orig_size = torch.tensor([w, h])[None]
133149

134-
transforms = T.Compose(
135-
[
136-
T.Resize((640, 640)),
137-
T.ToTensor(),
138-
]
139-
)
140-
img = transforms(page_img)[None]
141-
# Predict
142-
with torch.no_grad():
143-
labels, boxes, scores = self.model(img, orig_size)
150+
result = results[0]
151+
for score, label_id, box in zip(
152+
result["scores"], result["labels"], result["boxes"]
153+
):
154+
score = float(score.item())
155+
156+
label_id = int(label_id.item()) + 1 # Advance the label_id
157+
label_str = self._classes_map[label_id]
144158

145-
# Yield output
146-
for label_idx, box, score in zip(labels[0], boxes[0], scores[0]):
147159
# Filter out blacklisted classes
148-
label_idx = int(label_idx.item())
149-
score = float(score.item())
150-
label = self._classes_map[label_idx + 1]
151-
if label in self._black_classes:
160+
if label_str in self._black_classes:
152161
continue
153162

154-
# Check against threshold
155-
if score > self._threshold:
156-
l = min(w, max(0, box[0]))
157-
t = min(h, max(0, box[1]))
158-
r = min(w, max(0, box[2]))
159-
b = min(h, max(0, box[3]))
160-
yield {
161-
"l": l,
162-
"t": t,
163-
"r": r,
164-
"b": b,
165-
"label": label,
166-
"confidence": score,
167-
}
163+
bbox_float = [float(b.item()) for b in box]
164+
l = min(w, max(0, bbox_float[0]))
165+
t = min(h, max(0, bbox_float[1]))
166+
r = min(w, max(0, bbox_float[2]))
167+
b = min(h, max(0, bbox_float[3]))
168+
yield {
169+
"l": l,
170+
"t": t,
171+
"r": r,
172+
"b": b,
173+
"label": label_str,
174+
"confidence": score,
175+
}

0 commit comments

Comments
 (0)