Skip to content

Commit 505fbf4

Browse files
feat: Refactor the LayoutPredictor to support all layout models (#121)
Signed-off-by: Nikos Livathinos <[email protected]>
1 parent a9feac8 commit 505fbf4

File tree

4 files changed

+131
-62
lines changed

4 files changed

+131
-62
lines changed

demo/demo_layout_predictor.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sys
99
import time
1010
from pathlib import Path
11-
11+
from typing import Any, Dict, List
1212
import numpy as np
1313
import torch
1414
from huggingface_hub import snapshot_download
@@ -17,7 +17,9 @@
1717
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
1818

1919

20-
def save_predictions(prefix: str, viz_dir: str, img_fn: str, img, predictions: dict):
20+
def save_predictions(
21+
prefix: str, viz_dir: str, img_fn: Path, img, predictions: List[Dict[str, Any]]
22+
):
2123
img_path = Path(img_fn)
2224

2325
image = img.copy()
@@ -37,7 +39,7 @@ def save_predictions(prefix: str, viz_dir: str, img_fn: str, img, predictions: d
3739
confidence = round(pred["confidence"], 3)
3840

3941
# Save the predictions in txt file
40-
pred_txt = f"{prefix} {img_fn}: {label} - {bbox} - {confidence}\n"
42+
pred_txt = f"{prefix} {str(img_fn)}: {label} - {bbox} - {confidence}\n"
4143
fd.write(pred_txt)
4244

4345
# Draw the bbox and label
@@ -59,6 +61,7 @@ def demo(
5961
num_threads: int,
6062
img_dir: str,
6163
viz_dir: str,
64+
threshold: float,
6265
):
6366
r"""
6467
Apply LayoutPredictor on the input image directory
@@ -67,7 +70,7 @@ def demo(
6770
pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0)
6871
"""
6972
# Create the layout predictor
70-
lpredictor = LayoutPredictor(artifact_path, device=device, num_threads=num_threads)
73+
predictor = LayoutPredictor(artifact_path, device=device, num_threads=num_threads, base_threshold=threshold)
7174

7275
# Predict all test png images
7376
t0 = time.perf_counter()
@@ -79,7 +82,7 @@ def demo(
7982
with Image.open(img_fn) as image:
8083
# Predict layout
8184
img_t0 = time.perf_counter()
82-
preds = list(lpredictor.predict(image))
85+
preds: List[Dict[str, Any]] = list(predictor.predict(image))
8386
img_ms = 1000 * (time.perf_counter() - img_t0)
8487
logger.debug("Prediction(ms): {:.2f}".format(img_ms))
8588

@@ -97,10 +100,12 @@ def demo(
97100

98101
def main(args):
99102
r""" """
100-
num_threads = int(args.num_threads) if args.num_threads is not None else None
103+
num_threads = int(args.num_threads) if args.num_threads is not None else 4
101104
device = args.device.lower()
102105
img_dir = args.img_dir
103106
viz_dir = args.viz_dir
107+
hugging_face_repo = args.hugging_face_repo
108+
threshold = float(args.threshold)
104109

105110
# Initialize logger
106111
logging.basicConfig(level=logging.DEBUG)
@@ -118,20 +123,36 @@ def main(args):
118123
Path(viz_dir).mkdir(parents=True, exist_ok=True)
119124

120125
# Download models from HF
121-
download_path = snapshot_download(
122-
repo_id="ds4sd/docling-models", revision="v2.1.0"
123-
)
124-
artifact_path = os.path.join(download_path, "model_artifacts/layout")
126+
download_path = snapshot_download(repo_id=hugging_face_repo)
125127

126128
# Test the LayoutPredictor
127-
demo(logger, artifact_path, device, num_threads, img_dir, viz_dir)
129+
demo(logger, download_path, device, num_threads, img_dir, viz_dir, threshold)
128130

129131

130132
if __name__ == "__main__":
131133
r"""
132134
python -m demo.demo_layout_predictor -i <images_dir>
133135
"""
134136
parser = argparse.ArgumentParser(description="Test the LayoutPredictor")
137+
138+
supported_hf_repos = [
139+
"ds4sd/docling-layout-old",
140+
"ds4sd/docling-layout-heron",
141+
"ds4sd/docling-layout-heron-101",
142+
"ds4sd/docling-layout-egret-medium",
143+
"ds4sd/docling-layout-egret-large",
144+
"ds4sd/docling-layout-egret-xlarge",
145+
]
146+
parser.add_argument(
147+
"-r",
148+
"--hugging-face-repo",
149+
required=False,
150+
default="ds4sd/docling-layout-old",
151+
help=f"The hugging face repo id: [{', '.join(supported_hf_repos)}]",
152+
)
153+
parser.add_argument(
154+
"-t", "--threshold", required=False, default=0.3, help="Threshold for the LayoutPredictor"
155+
)
135156
parser.add_argument(
136157
"-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]"
137158
)
docling_ibm_models/layoutmodel/labels.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from typing import Dict
2+
3+
4+
class LayoutLabels:
    r"""
    Single point of reference for the layout labels

    Two id -> label maps are kept, together with their label -> id inverses:

    - canonical: class ids 0..16 without a background class
    - shifted canonical: the same labels moved up by one, with "Background" at id 0
    """

    def __init__(self) -> None:
        r"""
        Build the canonical and the shifted canonical maps and their inverses
        """
        # Canonical classes originating in DLNv2
        self._canonical: Dict[int, str] = {
            # DLNv1 and DLNv2
            0: "Caption",
            1: "Footnote",
            2: "Formula",
            3: "List-item",
            4: "Page-footer",
            5: "Page-header",
            6: "Picture",
            7: "Section-header",
            8: "Table",
            9: "Text",
            10: "Title",
            # DLNv2 only
            11: "Document Index",
            12: "Code",
            13: "Checkbox-Selected",
            14: "Checkbox-Unselected",
            15: "Form",
            16: "Key-Value Region",
        }
        self._inverse_canonical: Dict[str, int] = {
            label: class_id for class_id, label in self._canonical.items()
        }

        # Shifted canonical classes with background in 0
        self._shifted_canonical: Dict[int, str] = {0: "Background"}
        for k, v in self._canonical.items():
            self._shifted_canonical[k + 1] = v
        self._inverse_shifted_canonical: Dict[str, int] = {
            label: class_id for class_id, label in self._shifted_canonical.items()
        }

    def canonical_categories(self) -> Dict[int, str]:
        r"""
        Get the canonical map: class id -> label

        The internal dict itself is returned, not a copy.
        """
        return self._canonical

    def canonical_to_int(self) -> Dict[str, int]:
        r"""
        Get the inverse canonical map: label -> class id

        The internal dict itself is returned, not a copy.
        """
        return self._inverse_canonical

    def shifted_canonical_categories(self) -> Dict[int, str]:
        r"""
        Get the shifted canonical map: class id -> label, with "Background" at id 0

        The internal dict itself is returned, not a copy.
        """
        return self._shifted_canonical

    def shifted_canonical_to_int(self) -> Dict[str, int]:
        r"""
        Get the inverse shifted canonical map: label -> class id

        The internal dict itself is returned, not a copy.
        """
        return self._inverse_shifted_canonical

docling_ibm_models/layoutmodel/layout_predictor.py

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import os
77
import threading
88
from collections.abc import Iterable
9-
from typing import Set, Union
9+
from typing import Dict, List, Set, Union
1010

1111
import numpy as np
1212
import torch
13-
import torchvision.transforms as T
1413
from PIL import Image
15-
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
14+
from torch import Tensor
15+
from transformers import AutoModelForObjectDetection, RTDetrImageProcessor
16+
17+
from docling_ibm_models.layoutmodel.labels import LayoutLabels
1618

1719
_log = logging.getLogger(__name__)
1820

@@ -46,70 +48,67 @@ def __init__(
4648
------
4749
FileNotFoundError when the model's torch file is missing
4850
"""
49-
# Initialize classes map:
50-
self._classes_map = {
51-
0: "background",
52-
1: "Caption",
53-
2: "Footnote",
54-
3: "Formula",
55-
4: "List-item",
56-
5: "Page-footer",
57-
6: "Page-header",
58-
7: "Picture",
59-
8: "Section-header",
60-
9: "Table",
61-
10: "Text",
62-
11: "Title",
63-
12: "Document Index",
64-
13: "Code",
65-
14: "Checkbox-Selected",
66-
15: "Checkbox-Unselected",
67-
16: "Form",
68-
17: "Key-Value Region",
69-
}
70-
7151
# Blacklisted classes
7252
self._black_classes = blacklist_classes # set(["Form", "Key-Value Region"])
7353

54+
# Canonical classes
55+
self._labels = LayoutLabels()
56+
7457
# Set basic params
7558
self._threshold = base_threshold # Score threshold
76-
self._image_size = 640
77-
self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
7859

7960
# Set number of threads for CPU
8061
self._device = torch.device(device)
8162
self._num_threads = num_threads
8263
if device == "cpu":
8364
torch.set_num_threads(self._num_threads)
8465

85-
# Model file and configurations
66+
# Load model file and configurations
67+
self._processor_config = os.path.join(artifact_path, "preprocessor_config.json")
68+
self._model_config = os.path.join(artifact_path, "config.json")
8669
self._st_fn = os.path.join(artifact_path, "model.safetensors")
8770
if not os.path.isfile(self._st_fn):
8871
raise FileNotFoundError("Missing safe tensors file: {}".format(self._st_fn))
72+
if not os.path.isfile(self._processor_config):
73+
raise FileNotFoundError(
74+
f"Missing processor config file: {self._processor_config}"
75+
)
76+
if not os.path.isfile(self._model_config):
77+
raise FileNotFoundError(f"Missing model config file: {self._model_config}")
8978

9079
# Load model and move to device
91-
processor_config = os.path.join(artifact_path, "preprocessor_config.json")
92-
model_config = os.path.join(artifact_path, "config.json")
93-
self._image_processor = RTDetrImageProcessor.from_json_file(processor_config)
80+
self._image_processor = RTDetrImageProcessor.from_json_file(
81+
self._processor_config
82+
)
9483

9584
# Use lock to prevent threading issues during model initialization
9685
with _model_init_lock:
97-
self._model = RTDetrForObjectDetection.from_pretrained(
98-
artifact_path, config=model_config
86+
self._model = AutoModelForObjectDetection.from_pretrained(
87+
artifact_path, config=self._model_config
9988
).to(self._device)
10089
self._model.eval()
10190

91+
# Set classes map
92+
self._model_name = type(self._model).__name__
93+
if self._model_name == "RTDetrForObjectDetection":
94+
self._classes_map = self._labels.shifted_canonical_categories()
95+
self._label_offset = 1
96+
else:
97+
self._classes_map = self._labels.canonical_categories()
98+
self._label_offset = 0
99+
102100
_log.debug("LayoutPredictor settings: {}".format(self.info()))
103101

104102
def info(self) -> dict:
105103
"""
106104
Get information about the configuration of LayoutPredictor
107105
"""
108106
info = {
107+
"model_name": self._model_name,
109108
"safe_tensors_file": self._st_fn,
110109
"device": self._device.type,
111110
"num_threads": self._num_threads,
112-
"image_size": self._image_size,
111+
"image_size": self._image_processor.size,
113112
"threshold": self._threshold,
114113
}
115114
return info
@@ -141,28 +140,27 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
141140
else:
142141
raise TypeError("Not supported input image format")
143142

144-
resize = {"height": self._image_size, "width": self._image_size}
145-
inputs = self._image_processor(
146-
images=page_img,
147-
return_tensors="pt",
148-
size=resize,
149-
).to(self._device)
143+
target_sizes = torch.tensor([page_img.size[::-1]])
144+
inputs = self._image_processor(images=[page_img], return_tensors="pt").to(
145+
self._device
146+
)
150147
outputs = self._model(**inputs)
151-
results = self._image_processor.post_process_object_detection(
152-
outputs,
153-
target_sizes=torch.tensor([page_img.size[::-1]]),
154-
threshold=self._threshold,
148+
results: List[Dict[str, Tensor]] = (
149+
self._image_processor.post_process_object_detection(
150+
outputs,
151+
target_sizes=target_sizes,
152+
threshold=self._threshold,
153+
)
155154
)
156155

157156
w, h = page_img.size
158-
159157
result = results[0]
160158
for score, label_id, box in zip(
161159
result["scores"], result["labels"], result["boxes"]
162160
):
163161
score = float(score.item())
164162

165-
label_id = int(label_id.item()) + 1 # Advance the label_id
163+
label_id = int(label_id.item()) + self._label_offset
166164
label_str = self._classes_map[label_id]
167165

168166
# Filter out blacklisted classes

tests/test_layout_predictor.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
#
55
import os
66
import json
7-
from pathlib import Path
87

9-
import torch
108
import numpy as np
119
import pytest
1210
from huggingface_hub import snapshot_download
13-
from PIL import Image, ImageDraw, ImageFont
11+
from PIL import Image
1412

1513
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
1614

@@ -35,8 +33,7 @@ def init() -> dict:
3533
}
3634

3735
# Download models from HF
38-
download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.1.0")
39-
artifact_path = os.path.join(download_path, "model_artifacts/layout")
36+
artifact_path = snapshot_download(repo_id="ds4sd/docling-layout-old")
4037

4138
# Add the missing config keys
4239
init["artifact_path"] = artifact_path

0 commit comments

Comments
 (0)