Commit 59e2941

nassarofficial, Ahmed, and Maxim Lysak authored
feat: migration from onnx to pytorch script (#37)
---------

Signed-off-by: Ahmed <[email protected]>
Signed-off-by: Maxim Lysak <[email protected]>
Co-authored-by: Ahmed <[email protected]>
Co-authored-by: Maxim Lysak <[email protected]>
1 parent 2a8be46 commit 59e2941

4 files changed: +705 -457 lines changed
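The substance of the change: `LayoutPredictor` no longer builds an `onnxruntime.InferenceSession` with explicit execution providers; it loads `model.pt` with `torch.jit.load` and runs inference under `torch.no_grad()`, so the checkpoint is now expected to be a TorchScript archive. Preprocessing moves from manual numpy scaling to `torchvision.transforms` (`T.ToTensor` performs the same scaling of pixel values to [0, 1] that the old `/ 255.0` did). As a rough sketch of how such a checkpoint could be produced (not part of this commit; the model class below is a hypothetical stand-in):

import torch
import torch.nn as nn


class DummyLayoutModel(nn.Module):
    """Hypothetical stand-in with the (images, orig_target_sizes) signature."""

    def forward(self, images, orig_target_sizes):
        batch = images.shape[0]
        # Return one dummy detection per image: (labels, boxes, scores)
        labels = torch.zeros(batch, 1, dtype=torch.int64)
        boxes = torch.zeros(batch, 1, 4)
        scores = torch.zeros(batch, 1)
        return labels, boxes, scores


model = DummyLayoutModel().eval()
example_img = torch.rand(1, 3, 640, 640)    # dummy 640x640 RGB batch
example_sizes = torch.tensor([[640, 640]])  # original (w, h) per image
traced = torch.jit.trace(model, (example_img, example_sizes))
traced.save("model.pt")                     # the file __init__ now looks for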

docling_ibm_models/layoutmodel/layout_predictor.py

Lines changed: 31 additions & 47 deletions
@@ -7,7 +7,8 @@
 from typing import Union

 import numpy as np
-import onnxruntime as ort
+import torch
+import torchvision.transforms as T
 from PIL import Image

 MODEL_CHECKPOINT_FN = "model.pt"
@@ -16,14 +17,14 @@

 class LayoutPredictor:
     r"""
-    Document layout prediction using ONNX
+    Document layout prediction using torch
     """

     def __init__(
         self, artifact_path: str, num_threads: int = None, use_cpu_only: bool = False
     ):
         r"""
-        Provide the artifact path that contains the LayoutModel ONNX file
+        Provide the artifact path that contains the LayoutModel file

         The number of threads is decided, in the following order, by:
         1. The init method parameter `num_threads`, if it is set.
@@ -38,13 +39,13 @@ def __init__(

         Parameters
         ----------
-        artifact_path: Path for the model ONNX file.
+        artifact_path: Path for the model torch file.
         num_threads: (Optional) Number of threads to run the inference.
         use_cpu_only: (Optional) If True, it forces CPU as the execution provider.

         Raises
         ------
-        FileNotFoundError when the model's ONNX file is missing
+        FileNotFoundError when the model's torch file is missing
         """
         # Initialize classes map:
         self._classes_map = {
@@ -75,46 +76,27 @@ def __init__(
         self._threshold = 0.6  # Score threshold
         self._image_size = 640
         self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
+        self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
+
+        # Model file
+        self._torch_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
+        if not os.path.isfile(self._torch_fn):
+            raise FileNotFoundError("Missing torch file: {}".format(self._torch_fn))

         # Get env vars
-        self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
         if num_threads is None:
             num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
         self._num_threads = num_threads

-        # Decide the execution providers
-        if (
-            not self._use_cpu_only
-            and "CUDAExecutionProvider" in ort.get_available_providers()
-        ):
-            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
-        else:
-            providers = ["CPUExecutionProvider"]
-        self._providers = providers
-
-        # Model ONNX file
-        self._onnx_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
-        if not os.path.isfile(self._onnx_fn):
-            raise FileNotFoundError("Missing ONNX file: {}".format(self._onnx_fn))
-
-        # ONNX options
-        self._options = ort.SessionOptions()
-        self._options.intra_op_num_threads = self._num_threads
-        self.sess = ort.InferenceSession(
-            self._onnx_fn,
-            sess_options=self._options,
-            providers=self._providers,
-        )
+        self.model = torch.jit.load(self._torch_fn)

     def info(self) -> dict:
         r"""
         Get information about the configuration of LayoutPredictor
         """
         info = {
-            "onnx_file": self._onnx_fn,
-            "intra_op_num_threads": self._num_threads,
+            "torch_file": self._torch_fn,
             "use_cpu_only": self._use_cpu_only,
-            "providers": self._providers,
             "image_size": self._image_size,
             "threshold": self._threshold,
         }
@@ -147,33 +129,35 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
             raise TypeError("Not supported input image format")

         w, h = page_img.size
-        page_img = page_img.resize((self._image_size, self._image_size))
-        page_data = np.array(page_img, dtype=np.uint8) / np.float32(255.0)
-        page_data = np.expand_dims(np.transpose(page_data, axes=[2, 0, 1]), axis=0)
+        orig_size = torch.tensor([w, h])[None]

-        # Predict
-        labels, boxes, scores = self.sess.run(
-            output_names=None,
-            input_feed={
-                "images": page_data,
-                "orig_target_sizes": self._size,
-            },
+        transforms = T.Compose(
+            [
+                T.Resize((640, 640)),
+                T.ToTensor(),
+            ]
         )
+        img = transforms(page_img)[None]
+        # Predict
+        with torch.no_grad():
+            labels, boxes, scores = self.model(img, orig_size)

         # Yield output
         for label_idx, box, score in zip(labels[0], boxes[0], scores[0]):
             # Filter out blacklisted classes
-            label = self._classes_map[label_idx]
+            label_idx = int(label_idx.item())
+            score = float(score.item())
+            label = self._classes_map[label_idx + 1]
             if label in self._black_classes:
                 continue

             # Check against threshold
             if score > self._threshold:
                 yield {
-                    "l": box[0] / self._image_size * w,
-                    "t": box[1] / self._image_size * h,
-                    "r": box[2] / self._image_size * w,
-                    "b": box[3] / self._image_size * h,
+                    "l": box[0],
+                    "t": box[1],
+                    "r": box[2],
+                    "b": box[3],
                     "label": label,
                     "confidence": score,
                 }
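For reference, a minimal usage sketch of the migrated predictor (the artifact directory and page image below are hypothetical; the directory just needs to contain the TorchScript `model.pt`):

from PIL import Image

from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor

predictor = LayoutPredictor("./artifacts/layout", num_threads=4)
print(predictor.info())  # now reports "torch_file" instead of "onnx_file"

page = Image.open("page.png").convert("RGB")
for pred in predictor.predict(page):
    # Boxes come back directly in original-page coordinates: the model now
    # receives orig_size and does the rescaling itself, so the old manual
    # box / image_size * (w, h) scaling is gone.
    print(pred["label"], round(pred["confidence"], 2),
          pred["l"], pred["t"], pred["r"], pred["b"])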
