Skip to content

Commit 04b2e6f

Browse files
Feat/detectron2 mask rcnn x101 (#149)
* Update CHANGELOG
* Added alternative model for detectron2
* Added script for showing results
* Corrected index at getting output from detectron2
* Version sync
* Detectron2 with architecture mask-rcnn-x101 is now the default model
* Change default values during test for detectron2onnx
* Format changes
* Changed type of stored model_path

Signed-off-by: Benjamin Torres <[email protected]>
Co-authored-by: Alan Bertl <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* chore(deps-dev): Bump ipython from 8.12.2 to 8.14.0 in /requirements
* chore(deps): Bump transformers from 4.29.2 to 4.30.2 in /requirements (#142)
* chore(deps): Bump transformers from 4.29.2 to 4.30.2 in /requirements
* chore(deps): Bump opencv-python from 4.7.0.72 to 4.8.0.74 in /requirements (#137)
* chore(deps): Bump opencv-python in /requirements

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Benjamin Torres <[email protected]>
1 parent 8913620 commit 04b2e6f

File tree

6 files changed

+46
-26
lines changed

6 files changed

+46
-26
lines changed

CHANGELOG.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1-
## 0.5.8-dev1
1+
## 0.5.8-dev2
22

3-
* Cache named models that have been lodaed
3+
* Add alternative architecture for detectron2
4+
* Updates:
5+
6+
| Library | From | To |
7+
|---------------|-----------|----------|
8+
| transformers | 4.29.2 | 4.30.2 |
9+
| opencv-python | 4.7.0.72 | 4.8.0.74 |
10+
| ipython | 8.12.2 | 8.14.0 |
11+
12+
* Cache named models that have been loaded
413

514
## 0.5.7
615

requirements/dev.txt

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,8 @@
66
#
77
anyio==3.7.1
88
# via
9-
# -c requirements/test.txt
9+
# -c test.txt
1010
# jupyter-server
11-
appnope==0.1.3
12-
# via
13-
# ipykernel
14-
# ipython
1511
argon2-cffi==21.3.0
1612
# via jupyter-server
1713
argon2-cffi-bindings==21.2.0
@@ -43,7 +39,7 @@ certifi==2023.7.22
4339
# requests
4440
cffi==1.15.1
4541
# via
46-
# -c requirements/base.txt
42+
# -c base.txt
4743
# argon2-cffi-bindings
4844
charset-normalizer==3.2.0
4945
# via
@@ -52,7 +48,7 @@ charset-normalizer==3.2.0
5248
# requests
5349
click==8.1.6
5450
# via
55-
# -c requirements/test.txt
51+
# -c test.txt
5652
# pip-tools
5753
comm==0.1.3
5854
# via ipykernel
@@ -74,8 +70,8 @@ fqdn==1.5.1
7470
# via jsonschema
7571
idna==3.4
7672
# via
77-
# -c requirements/base.txt
78-
# -c requirements/test.txt
73+
# -c base.txt
74+
# -c test.txt
7975
# anyio
8076
# jsonschema
8177
# requests
@@ -100,9 +96,9 @@ ipykernel==6.25.0
10096
# jupyter-console
10197
# jupyterlab
10298
# qtconsole
103-
ipython==8.12.2
99+
ipython==8.14.0
104100
# via
105-
# -r requirements/dev.in
101+
# -r dev.in
106102
# ipykernel
107103
# ipywidgets
108104
# jupyter-console
@@ -116,7 +112,7 @@ jedi==0.18.2
116112
# via ipython
117113
jinja2==3.1.2
118114
# via
119-
# -c requirements/base.txt
115+
# -c base.txt
120116
# jupyter-server
121117
# jupyterlab
122118
# jupyterlab-server
@@ -179,7 +175,7 @@ jupyterlab-widgets==3.0.8
179175
# via ipywidgets
180176
markupsafe==2.1.3
181177
# via
182-
# -c requirements/base.txt
178+
# -c base.txt
183179
# jinja2
184180
# nbconvert
185181
matplotlib-inline==0.1.6
@@ -211,8 +207,8 @@ overrides==7.3.1
211207
# via jupyter-server
212208
packaging==23.1
213209
# via
214-
# -c requirements/base.txt
215-
# -c requirements/test.txt
210+
# -c base.txt
211+
# -c test.txt
216212
# build
217213
# ipykernel
218214
# jupyter-server
@@ -235,7 +231,7 @@ pkgutil-resolve-name==1.3.10
235231
# via jsonschema
236232
platformdirs==3.9.1
237233
# via
238-
# -c requirements/test.txt
234+
# -c test.txt
239235
# jupyter-core
240236
prometheus-client==0.17.1
241237
# via jupyter-server
@@ -253,7 +249,7 @@ pure-eval==0.2.2
253249
# via stack-data
254250
pycparser==2.21
255251
# via
256-
# -c requirements/base.txt
252+
# -c base.txt
257253
# cffi
258254
pygments==2.15.1
259255
# via
@@ -315,14 +311,14 @@ send2trash==1.8.2
315311
# via jupyter-server
316312
six==1.16.0
317313
# via
318-
# -c requirements/base.txt
314+
# -c base.txt
319315
# asttokens
320316
# bleach
321317
# python-dateutil
322318
# rfc3339-validator
323319
sniffio==1.3.0
324320
# via
325-
# -c requirements/test.txt
321+
# -c test.txt
326322
# anyio
327323
soupsieve==2.4.1
328324
# via beautifulsoup4

test_unstructured_inference/models/test_detectron2onnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self, *args, **kwargs):
1414
self.kwargs = kwargs
1515

1616
def run(self, *args):
17-
return ([(1, 2, 3, 4)], [0], [0.818], [(4, 5)])
17+
return ([(1, 2, 3, 4)], [0], [(4, 5)], [0.818])
1818

1919
def get_inputs(self):
2020
class input_thing:
@@ -30,7 +30,7 @@ def test_load_default_model(monkeypatch):
3030
"InferenceSession",
3131
new=MockDetectron2ONNXLayoutModel,
3232
):
33-
model = models.get_model("detectron2_onnx")
33+
model = models.get_model("detectron2_mask_rcnn")
3434

3535
assert isinstance(model.model, MockDetectron2ONNXLayoutModel)
3636

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.8-dev1" # pragma: no cover
1+
__version__ = "0.5.8-dev2" # pragma: no cover

unstructured_inference/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
UnstructuredYoloXModel,
2424
)
2525

26-
DEFAULT_MODEL = "detectron2_onnx"
26+
DEFAULT_MODEL = "detectron2_mask_rcnn"
2727

2828
models: Dict[str, UnstructuredModel] = {}
2929

unstructured_inference/models/detectron2onnx.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@
3737
label_map=DEFAULT_LABEL_MAP,
3838
confidence_threshold=0.8,
3939
),
40+
"detectron2_mask_rcnn": LazyDict(
41+
model_path=LazyEvaluateInfo(
42+
hf_hub_download,
43+
"unstructuredio/detectron2_mask_rcnn_X_101_32x8d_FPN_3x",
44+
"model.onnx",
45+
),
46+
label_map=DEFAULT_LABEL_MAP,
47+
confidence_threshold=0.8,
48+
),
4049
}
4150

4251

@@ -53,7 +62,12 @@ def predict(self, image: Image.Image) -> List[LayoutElement]:
5362

5463
prepared_input = self.preprocess(image)
5564
try:
56-
bboxes, labels, confidence_scores, _ = self.model.run(None, prepared_input)
65+
result = self.model.run(None, prepared_input)
66+
bboxes = result[0]
67+
labels = result[1]
68+
# Previous model detectron2_onnx stored confidence scores at index 2,
69+
# bigger model stores it at index 3
70+
confidence_scores = result[2] if "R_50" in self.model_path else result[3]
5771
except onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException:
5872
logger_onnx.debug(
5973
"Ignoring runtime error from onnx (likely due to encountering blank page).",
@@ -72,6 +86,7 @@ def initialize(
7286
):
7387
"""Loads the detectron2 model using the specified parameters"""
7488
logger.info("Loading the Detectron2 layout model ...")
89+
self.model_path = str(model_path)
7590
self.model = onnxruntime.InferenceSession(
7691
model_path,
7792
providers=[

0 commit comments

Comments
 (0)