Skip to content

Commit 9c2c669

Browse files
authored
chore: dl default model from huggingface (#30)
We've now seen cases where Dropbox links are blacklisted, which prevents users from downloading the default detectron2 model. This PR changes the default model retrieval method to use Hugging Face (which already hosts the weights and config for the model). The code is also updated so that getting the default model and getting a non-default model share more of the same code.
1 parent c25f2cd commit 9c2c669

File tree

11 files changed

+91
-72
lines changed

11 files changed

+91
-72
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
## 0.2.4-dev1
1+
## 0.2.4
22

3+
* Download default model from huggingface
34
* Clarify error when trying to open file that doesn't exist as an image
45

56
## 0.2.3

test_unstructured_inference/inference/test_layout.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from PIL import Image
99

1010
import unstructured_inference.inference.layout as layout
11-
import unstructured_inference.models as models
11+
import unstructured_inference.models.base as models
1212

1313
import unstructured_inference.models.detectron2 as detectron2
1414
import unstructured_inference.models.tesseract as tesseract
@@ -66,7 +66,9 @@ def detect(self, *args):
6666

6767

6868
def test_get_page_elements(monkeypatch, mock_page_layout):
69-
monkeypatch.setattr(detectron2, "load_default_model", lambda: MockLayoutModel(mock_page_layout))
69+
monkeypatch.setattr(
70+
models, "load_model", lambda *args, **kwargs: MockLayoutModel(mock_page_layout)
71+
)
7072
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
7173

7274
image = np.random.randint(12, 24, (40, 40))
@@ -88,7 +90,7 @@ def test_get_page_elements_with_ocr(monkeypatch):
8890
text_block = TextBlock(rectangle, text=None, type="Title")
8991
doc_layout = Layout([text_block])
9092

91-
monkeypatch.setattr(detectron2, "load_default_model", lambda: MockLayoutModel(doc_layout))
93+
monkeypatch.setattr(models, "load_model", lambda *args, **kwargs: MockLayoutModel(doc_layout))
9294
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
9395

9496
image = np.random.randint(12, 24, (40, 40))
@@ -104,7 +106,9 @@ def test_read_pdf(monkeypatch, mock_page_layout):
104106

105107
layouts = Layout([mock_page_layout, mock_page_layout])
106108

107-
monkeypatch.setattr(detectron2, "load_default_model", lambda: MockLayoutModel(mock_page_layout))
109+
monkeypatch.setattr(
110+
models, "load_model", lambda *args, **kwargs: MockLayoutModel(mock_page_layout)
111+
)
108112
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
109113

110114
with patch.object(lp, "load_pdf", return_value=(layouts, images)):
@@ -138,6 +142,7 @@ def test_process_data_with_model(monkeypatch, mock_page_layout, model_name):
138142
"fake-binary-path",
139143
"fake-config-path",
140144
{0: "Unchecked", 1: "Checked"},
145+
None,
141146
),
142147
)
143148
with patch("builtins.open", mock_open(read_data=b"000000")):
@@ -168,6 +173,7 @@ def test_process_file_with_model(monkeypatch, mock_page_layout, model_name):
168173
"fake-binary-path",
169174
"fake-config-path",
170175
{0: "Unchecked", 1: "Checked"},
176+
None,
171177
),
172178
)
173179
filename = ""

test_unstructured_inference/models/test_detectron2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import patch
33

44
import unstructured_inference.models.detectron2 as detectron2
5+
import unstructured_inference.models.base as models
56

67

78
class MockDetectron2LayoutModel:
@@ -14,15 +15,15 @@ def test_load_default_model(monkeypatch):
1415
monkeypatch.setattr(detectron2, "Detectron2LayoutModel", MockDetectron2LayoutModel)
1516

1617
with patch.object(detectron2, "is_detectron2_available", return_value=True):
17-
model = detectron2.load_default_model()
18+
model = models.get_model()
1819

1920
assert isinstance(model, MockDetectron2LayoutModel)
2021

2122

2223
def test_load_default_model_raises_when_not_available():
2324
with patch.object(detectron2, "is_detectron2_available", return_value=False):
2425
with pytest.raises(ImportError):
25-
detectron2.load_default_model()
26+
models.get_model()
2627

2728

2829
@pytest.mark.parametrize("config_path, model_path", [("asdf", "diufs"), ("dfaw", "hfhfhfh")])

test_unstructured_inference/models/test_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from unstructured_inference import models
3+
import unstructured_inference.models.base as models
44

55

66
class MockModel:
@@ -18,6 +18,7 @@ def test_get_model(monkeypatch):
1818
"fake-binary-path",
1919
"fake-config-path",
2020
{0: "Unchecked", 1: "Checked"},
21+
None,
2122
),
2223
)
2324
assert isinstance(models.get_model("checkbox"), MockModel)

test_unstructured_inference/test_api.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from fastapi.testclient import TestClient
55

66
from unstructured_inference.api import app
7-
from unstructured_inference import models
7+
from unstructured_inference.models import base as models
88
from unstructured_inference.inference.layout import DocumentLayout
99
import unstructured_inference.models.detectron2 as detectron2
1010

@@ -15,8 +15,18 @@ def __init__(self, *args, **kwargs):
1515
self.kwargs = kwargs
1616

1717

18-
@pytest.mark.parametrize("filetype, ext", [("pdf", "pdf"), ("image", "png")])
19-
def test_layout_parsing_api(monkeypatch, filetype, ext):
18+
@pytest.mark.parametrize(
19+
"filetype, ext, data, response_code",
20+
[
21+
("pdf", "pdf", None, 200),
22+
("pdf", "pdf", {"model": "checkbox"}, 200),
23+
("pdf", "pdf", {"model": "fake_model"}, 422),
24+
("image", "png", None, 200),
25+
("image", "png", {"model": "checkbox"}, 200),
26+
("image", "png", {"model": "fake_model"}, 422),
27+
],
28+
)
29+
def test_layout_parsing_api(monkeypatch, filetype, ext, data, response_code):
2030
monkeypatch.setattr(models, "load_model", lambda *args, **kwargs: MockModel(*args, **kwargs))
2131
monkeypatch.setattr(models, "hf_hub_download", lambda *args, **kwargs: "fake-path")
2232
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
@@ -30,22 +40,10 @@ def test_layout_parsing_api(monkeypatch, filetype, ext):
3040
filename = os.path.join("sample-docs", f"loremipsum.{ext}")
3141

3242
client = TestClient(app)
33-
response = client.post(f"/layout/{filetype}", files={"file": (filename, open(filename, "rb"))})
34-
assert response.status_code == 200
35-
36-
response = client.post(
37-
f"/layout/{filetype}",
38-
files={"file": (filename, open(filename, "rb"))},
39-
data={"model": "checkbox"},
40-
)
41-
assert response.status_code == 200
42-
4343
response = client.post(
44-
f"/layout/{filetype}",
45-
files={"file": (filename, open(filename, "rb"))},
46-
data={"model": "fake_model"},
44+
f"/layout/{filetype}", files={"file": (filename, open(filename, "rb"))}, data=data
4745
)
48-
assert response.status_code == 422
46+
assert response.status_code == response_code
4947

5048

5149
def test_bad_route_404():
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.2.4-dev1" # pragma: no cover
1+
__version__ = "0.2.4" # pragma: no cover

unstructured_inference/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from fastapi import FastAPI, File, status, Request, UploadFile, Form, HTTPException
22
from unstructured_inference.inference.layout import process_data_with_model
3-
from unstructured_inference.models import UnknownModelException
3+
from unstructured_inference.models.base import UnknownModelException
44
from typing import List
55

66
app = FastAPI()

unstructured_inference/inference/layout.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212

1313
from unstructured_inference.logger import logger
1414
import unstructured_inference.models.tesseract as tesseract
15-
import unstructured_inference.models.detectron2 as detectron2
16-
from unstructured_inference.models import get_model
15+
from unstructured_inference.models.base import get_model
1716

1817

1918
@dataclass
@@ -113,7 +112,7 @@ def get_elements(self, inplace=True) -> Optional[List[LayoutElement]]:
113112
"""Uses a layoutparser model to detect the elements on the page."""
114113
logger.info("Detecting page elements ...")
115114
if self.model is None:
116-
self.model = detectron2.load_default_model()
115+
self.model = get_model()
117116

118117
elements = list()
119118
# NOTE(mrobinson) - We'll want to make this model inference step some kind of
@@ -183,7 +182,7 @@ def process_file_with_model(
183182
) -> DocumentLayout:
184183
"""Processes pdf file with name filename into a DocumentLayout by using a model identified by
185184
model_name."""
186-
model = None if model_name is None else get_model(model_name)
185+
model = get_model(model_name)
187186
layout = (
188187
DocumentLayout.from_image_file(filename, model=model)
189188
if is_image
Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +0,0 @@
1-
from typing import Tuple, Dict
2-
from huggingface_hub import hf_hub_download
3-
4-
from unstructured_inference.models.detectron2 import load_model, Detectron2LayoutModel
5-
6-
7-
def get_model(model: str) -> Detectron2LayoutModel:
8-
"""Gets the model object by model name."""
9-
model_path, config_path, label_map = _get_model_loading_info(model)
10-
detector = load_model(config_path=config_path, model_path=model_path, label_map=label_map)
11-
12-
return detector
13-
14-
15-
def _get_model_loading_info(model: str) -> Tuple[str, str, Dict[int, str]]:
16-
"""Gets local model binary and config locations and label map, downloading if necessary."""
17-
# TODO(alan): Find the right way to map model name to retrieval. It seems off that testing
18-
# needs to mock hf_hub_download.
19-
if model == "checkbox":
20-
repo_id = "unstructuredio/oer-checkbox"
21-
binary_fn = "detectron2_finetuned_oer_checkbox.pth"
22-
config_fn = "detectron2_oer_checkbox.json"
23-
model_path = hf_hub_download(repo_id, binary_fn)
24-
config_path = hf_hub_download(repo_id, config_fn)
25-
label_map = {0: "Unchecked", 1: "Checked"}
26-
else:
27-
raise UnknownModelException(f"Unknown model type: {model}")
28-
return model_path, config_path, label_map
29-
30-
31-
class UnknownModelException(Exception):
32-
"""Exception for the case where a model is called for with an unrecognized identifier."""
33-
34-
pass
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from typing import Tuple, Dict, Optional, List, Any
2+
from huggingface_hub import hf_hub_download
3+
4+
from unstructured_inference.models.detectron2 import (
5+
load_model,
6+
Detectron2LayoutModel,
7+
DEFAULT_LABEL_MAP,
8+
DEFAULT_EXTRA_CONFIG,
9+
)
10+
11+
12+
def get_model(model: Optional[str] = None) -> Detectron2LayoutModel:
13+
"""Gets the model object by model name."""
14+
model_path, config_path, label_map, extra_config = _get_model_loading_info(model)
15+
detector = load_model(
16+
config_path=config_path,
17+
model_path=model_path,
18+
label_map=label_map,
19+
extra_config=extra_config,
20+
)
21+
return detector
22+
23+
24+
def _get_model_loading_info(
25+
model: Optional[str],
26+
) -> Tuple[str, str, Dict[int, str], Optional[List[Any]]]:
27+
"""Gets local model binary and config locations and label map, downloading if necessary."""
28+
# TODO(alan): Find the right way to map model name to retrieval. It seems off that testing
29+
# needs to mock hf_hub_download.
30+
if model is None:
31+
repo_id = "layoutparser/detectron2"
32+
binary_fn = "PubLayNet/faster_rcnn_R_50_FPN_3x/model_final.pth"
33+
config_fn = "PubLayNet/faster_rcnn_R_50_FPN_3x/config.yml"
34+
model_path = hf_hub_download(repo_id, binary_fn)
35+
config_path = hf_hub_download(repo_id, config_fn)
36+
label_map = DEFAULT_LABEL_MAP
37+
extra_config = DEFAULT_EXTRA_CONFIG
38+
elif model == "checkbox":
39+
repo_id = "unstructuredio/oer-checkbox"
40+
binary_fn = "detectron2_finetuned_oer_checkbox.pth"
41+
config_fn = "detectron2_oer_checkbox.json"
42+
model_path = hf_hub_download(repo_id, binary_fn)
43+
config_path = hf_hub_download(repo_id, config_fn)
44+
label_map = {0: "Unchecked", 1: "Checked"}
45+
extra_config = None
46+
else:
47+
raise UnknownModelException(f"Unknown model type: {model}")
48+
return model_path, config_path, label_map, extra_config
49+
50+
51+
class UnknownModelException(Exception):
52+
"""Exception for the case where a model is called for with an unrecognized identifier."""
53+
54+
pass

0 commit comments

Comments
 (0)