Skip to content

Commit ce402d8

Browse files
benjats07qued
andauthored
Added ONNX model for detectron2 (#103)
* Added ONNX model for detectron2 * Update tests * Add comments * Make original detectron2 and onnx separate models * Deletes detectron2 dependency from Makefile --------- Co-authored-by: qued <[email protected]> Co-authored-by: Alan Bertl <[email protected]>
1 parent 5d10b1f commit ce402d8

File tree

5 files changed

+201
-1
lines changed

5 files changed

+201
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.4.5-dev0
2+
3+
* Added ONNX version of Detectron2
4+
15
## 0.4.4
26

37
* Fixed patches not being a package.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
from unittest.mock import patch
3+
4+
import pytest
5+
from PIL import Image
6+
7+
import unstructured_inference.models.detectron2onnx as detectron2
8+
import unstructured_inference.models.base as models
9+
10+
11+
class MockDetectron2ONNXLayoutModel:
12+
def __init__(self, *args, **kwargs):
13+
self.args = args
14+
self.kwargs = kwargs
15+
16+
def run(self, *args):
17+
return ([(1, 2, 3, 4)], [0], [0.818], [(4, 5)])
18+
19+
def get_inputs(self):
20+
class input_thing:
21+
name = "Bernard"
22+
23+
return [input_thing()]
24+
25+
26+
def test_load_default_model():
27+
with patch.object(
28+
detectron2.onnxruntime, "InferenceSession", new=MockDetectron2ONNXLayoutModel
29+
):
30+
model = models.get_model("detectron2_onnx")
31+
32+
assert isinstance(model.model, MockDetectron2ONNXLayoutModel)
33+
34+
35+
@pytest.mark.parametrize(("model_path", "label_map"), [("asdf", "diufs"), ("dfaw", "hfhfhfh")])
36+
def test_load_model(model_path, label_map):
37+
with patch.object(detectron2.onnxruntime, "InferenceSession", return_value=True):
38+
model = detectron2.UnstructuredDetectronONNXModel()
39+
model.initialize(model_path=model_path, label_map=label_map)
40+
args, _ = detectron2.onnxruntime.InferenceSession.call_args
41+
assert args == (model_path,)
42+
assert label_map == model.label_map
43+
44+
45+
def test_unstructured_detectron_model():
46+
model = detectron2.UnstructuredDetectronONNXModel()
47+
model.model = 1
48+
with patch.object(detectron2.UnstructuredDetectronONNXModel, "predict", return_value=[]):
49+
result = model(None)
50+
assert isinstance(result, list)
51+
assert len(result) == 0
52+
53+
54+
def test_inference():
55+
with patch.object(
56+
detectron2.onnxruntime, "InferenceSession", return_value=MockDetectron2ONNXLayoutModel()
57+
):
58+
model = detectron2.UnstructuredDetectronONNXModel()
59+
model.initialize(model_path="test_path", label_map={0: "test_class"})
60+
assert isinstance(model.model, MockDetectron2ONNXLayoutModel)
61+
with open(os.path.join("sample-docs", "receipt-sample.jpg"), mode="rb") as fp:
62+
image = Image.open(fp)
63+
image.load()
64+
elements = model(image)
65+
assert len(elements) == 1
66+
element = elements[0]
67+
(x1, y1), _, (x2, y2), _ = element.coordinates
68+
# NOTE(alan): The bbox coordinates get resized, so check their relative proportions
69+
assert x2 / x1 == pytest.approx(3.0) # x1 == 1, x2 == 3 before scaling
70+
assert y2 / y1 == pytest.approx(2.0) # y1 == 2, y2 == 4 before scaling
71+
assert element.type == "test_class"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.4.4" # pragma: no cover
1+
__version__ = "0.4.5-dev0" # pragma: no cover

unstructured_inference/models/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
MODEL_TYPES as DETECTRON2_MODEL_TYPES,
66
UnstructuredDetectronModel,
77
)
8+
from unstructured_inference.models.detectron2onnx import (
9+
MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES,
10+
UnstructuredDetectronONNXModel,
11+
)
812
from unstructured_inference.models.yolox import (
913
MODEL_TYPES as YOLOX_MODEL_TYPES,
1014
UnstructuredYoloXModel,
@@ -18,6 +22,9 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
1822
if model_name in DETECTRON2_MODEL_TYPES:
1923
model: UnstructuredModel = UnstructuredDetectronModel()
2024
model.initialize(**DETECTRON2_MODEL_TYPES[model_name])
25+
elif model_name in DETECTRON2_ONNX_MODEL_TYPES:
26+
model = UnstructuredDetectronONNXModel()
27+
model.initialize(**DETECTRON2_ONNX_MODEL_TYPES[model_name])
2128
elif model_name in YOLOX_MODEL_TYPES:
2229
model = UnstructuredYoloXModel()
2330
model.initialize(**YOLOX_MODEL_TYPES[model_name])
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from typing import Final, Optional, Union, Dict, List
2+
from pathlib import Path
3+
4+
from PIL import Image
5+
from huggingface_hub import hf_hub_download
6+
7+
from unstructured_inference.logger import logger
8+
from unstructured_inference.inference.layoutelement import LayoutElement
9+
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
10+
from unstructured_inference.utils import LazyDict, LazyEvaluateInfo
11+
import onnxruntime
12+
import numpy as np
13+
import cv2
14+
15+
16+
DEFAULT_LABEL_MAP: Final[Dict[int, str]] = {
17+
0: "Text",
18+
1: "Title",
19+
2: "List",
20+
3: "Table",
21+
4: "Figure",
22+
}
23+
24+
25+
# NOTE(alan): Entries are implemented as LazyDicts so that models aren't downloaded until they are
26+
# needed.
27+
MODEL_TYPES: Dict[Optional[str], LazyDict] = {
28+
"detectron2_onnx": LazyDict(
29+
model_path=LazyEvaluateInfo(
30+
hf_hub_download,
31+
"unstructuredio/detectron2_faster_rcnn_R_50_FPN_3x",
32+
"model.onnx",
33+
),
34+
label_map=DEFAULT_LABEL_MAP,
35+
confidence_threshold=0.8,
36+
),
37+
}
38+
39+
40+
class UnstructuredDetectronONNXModel(UnstructuredModel):
41+
"""Unstructured model wrapper for detectron2 ONNX model."""
42+
43+
# The model was trained and exported with this shape
44+
required_w = 800
45+
required_h = 1035
46+
47+
def predict(self, image: Image.Image) -> List[LayoutElement]:
48+
"""Makes a prediction using detectron2 model."""
49+
super().predict(image)
50+
51+
prepared_input = self.preprocess(image)
52+
bboxes, labels, confidence_scores, _ = self.model.run(None, prepared_input)
53+
input_w, input_h = image.size
54+
regions = self.postprocess(bboxes, labels, confidence_scores, input_w, input_h)
55+
56+
return regions
57+
58+
def initialize(
59+
self,
60+
model_path: Union[str, Path],
61+
label_map: Dict[int, str],
62+
confidence_threshold: Optional[float] = None,
63+
):
64+
"""Loads the detectron2 model using the specified parameters"""
65+
logger.info("Loading the Detectron2 layout model ...")
66+
self.model = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
67+
self.label_map = label_map
68+
if confidence_threshold is None:
69+
confidence_threshold = 0.5
70+
self.confidence_threshold = confidence_threshold
71+
72+
def preprocess(self, image: Image.Image) -> Dict[str, np.ndarray]:
73+
"""Process input image into required format for ingestion into the Detectron2 ONNX binary.
74+
This involves resizing to a fixed shape and converting to a specific numpy format."""
75+
# TODO (benjamin): check other shapes for inference
76+
img = np.array(image)
77+
# TODO (benjamin): We should use models.get_model() but currenly returns Detectron model
78+
session = self.model
79+
# onnx input expected
80+
# [3,1035,800]
81+
img = cv2.resize(
82+
img,
83+
(self.required_w, self.required_h),
84+
interpolation=cv2.INTER_LINEAR,
85+
).astype(np.float32)
86+
img = img.transpose(2, 0, 1)
87+
ort_inputs = {session.get_inputs()[0].name: img}
88+
return ort_inputs
89+
90+
def postprocess(
91+
self,
92+
bboxes: np.ndarray,
93+
labels: np.ndarray,
94+
confidence_scores: np.ndarray,
95+
input_w: float,
96+
input_h: float,
97+
) -> List[LayoutElement]:
98+
"""Process output into Unstructured class. Bounding box coordinates are converted to
99+
original image resolution."""
100+
regions = []
101+
width_conversion = input_w / self.required_w
102+
height_conversion = input_h / self.required_h
103+
for (x1, y1, x2, y2), label, conf in zip(bboxes, labels, confidence_scores):
104+
detected_class = self.label_map[int(label)]
105+
if conf >= self.confidence_threshold:
106+
region = LayoutElement(
107+
x1 * width_conversion,
108+
y1 * height_conversion,
109+
x2 * width_conversion,
110+
y2 * height_conversion,
111+
text=None,
112+
type=detected_class,
113+
)
114+
115+
regions.append(region)
116+
117+
regions.sort(key=lambda element: element.y1)
118+
return regions

0 commit comments

Comments
 (0)