Skip to content

Commit 2d35ab6

Browse files
authored
Improve models structure (#53)
* reorganize model folder * Improve model catalogs * layout model needs to have DETECTOR_NAME
1 parent 861e0a0 commit 2d35ab6

File tree

6 files changed

+124
-72
lines changed

6 files changed

+124
-72
lines changed
Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1 @@
1-
from . import catalog as _UNUSED
2-
# A trick learned from
3-
# https://github.com/facebookresearch/detectron2/blob/62cf3a2b6840734d2717abdf96e2dd57ed6612a6/detectron2/checkpoint/__init__.py#L6
4-
from .layoutmodel import Detectron2LayoutModel
1+
from .detectron2.layoutmodel import Detectron2LayoutModel
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from iopath.common.file_io import HTTPURLHandler
2+
from iopath.common.file_io import PathManager as PathManagerBase
3+
4+
# A trick learned from https://github.com/facebookresearch/detectron2/blob/65faeb4779e4c142484deeece18dc958c5c9ad18/detectron2/utils/file_io.py#L3
5+
6+
7+
class DropboxHandler(HTTPURLHandler):
8+
"""
9+
Supports download and file check for dropbox links
10+
"""
11+
12+
def _get_supported_prefixes(self):
13+
return ["https://www.dropbox.com"]
14+
15+
def _isfile(self, path):
16+
return path in self.cache_map
17+
18+
19+
PathManager = PathManagerBase()
20+
PathManager.register_handler(DropboxHandler())
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from abc import ABC, abstractmethod
2+
import os
3+
import importlib
4+
5+
6+
class BaseLayoutModel(ABC):
7+
8+
@property
9+
@abstractmethod
10+
def DETECTOR_NAME(self):
11+
pass
12+
13+
@abstractmethod
14+
def detect(self):
15+
pass
16+
17+
# Add lazy loading mechanisms for layout models, refer to
18+
# layoutparser.ocr.BaseOCRAgent
19+
# TODO: Build a metaclass for lazy module loader
20+
@property
21+
@abstractmethod
22+
def DEPENDENCIES(self):
23+
"""DEPENDENCIES lists all necessary dependencies for the class."""
24+
pass
25+
26+
@property
27+
@abstractmethod
28+
def MODULES(self):
29+
"""MODULES instructs how to import these necessary libraries."""
30+
pass
31+
32+
@classmethod
33+
def _import_module(cls):
34+
for m in cls.MODULES:
35+
if importlib.util.find_spec(m["module_path"]):
36+
setattr(
37+
cls, m["import_name"], importlib.import_module(m["module_path"])
38+
)
39+
else:
40+
raise ModuleNotFoundError(
41+
f"\n "
42+
f"\nPlease install the following libraries to support the class {cls.__name__}:"
43+
f"\n pip install {' '.join(cls.DEPENDENCIES)}"
44+
f"\n "
45+
)
46+
47+
def __new__(cls, *args, **kwargs):
48+
49+
cls._import_module()
50+
return super().__new__(cls)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from . import catalog as _UNUSED
2+
# A trick learned from
3+
# https://github.com/facebookresearch/detectron2/blob/62cf3a2b6840734d2717abdf96e2dd57ed6612a6/detectron2/checkpoint/__init__.py#L6
4+
from .layoutmodel import Detectron2LayoutModel

src/layoutparser/models/catalog.py renamed to src/layoutparser/models/detectron2/catalog.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from iopath.common.file_io import PathHandler, PathManager, HTTPURLHandler
2-
from iopath.common.file_io import PathManager as PathManagerBase
1+
from iopath.common.file_io import PathHandler
32

4-
# A trick learned from https://github.com/facebookresearch/detectron2/blob/65faeb4779e4c142484deeece18dc958c5c9ad18/detectron2/utils/file_io.py#L3
3+
from ..base_catalog import PathManager
54

65
MODEL_CATALOG = {
76
"HJDataset": {
@@ -49,6 +48,7 @@
4948
},
5049
}
5150

51+
# fmt: off
5252
LABEL_MAP_CATALOG = {
5353
"HJDataset": {
5454
1: "Page Frame",
@@ -59,7 +59,12 @@
5959
6: "Subtitle",
6060
7: "Other",
6161
},
62-
"PubLayNet": {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"},
62+
"PubLayNet": {
63+
0: "Text",
64+
1: "Title",
65+
2: "List",
66+
3: "Table",
67+
4: "Figure"},
6368
"PrimaLayout": {
6469
1: "TextRegion",
6570
2: "ImageRegion",
@@ -77,34 +82,26 @@
7782
5: "Headline",
7883
6: "Advertisement",
7984
},
80-
"TableBank": {0: "Table"},
85+
"TableBank": {
86+
0: "Table"
87+
},
8188
}
89+
# fmt: on
8290

8391

84-
class DropboxHandler(HTTPURLHandler):
85-
"""
86-
Supports download and file check for dropbox links
87-
"""
88-
89-
def _get_supported_prefixes(self):
90-
return ["https://www.dropbox.com"]
91-
92-
def _isfile(self, path):
93-
return path in self.cache_map
94-
95-
96-
class LayoutParserHandler(PathHandler):
92+
class LayoutParserDetectron2ModelHandler(PathHandler):
9793
"""
9894
Resolve anything that's in LayoutParser model zoo.
9995
"""
10096

101-
PREFIX = "lp://"
97+
PREFIX = "lp://detectron2/"
10298

10399
def _get_supported_prefixes(self):
104100
return [self.PREFIX]
105101

106102
def _get_local_path(self, path, **kwargs):
107103
model_name = path[len(self.PREFIX) :]
104+
108105
dataset_name, *model_name, data_type = model_name.split("/")
109106

110107
if data_type == "weight":
@@ -119,6 +116,4 @@ def _open(self, path, mode="r", **kwargs):
119116
return PathManager.open(self._get_local_path(path), mode, **kwargs)
120117

121118

122-
PathManager = PathManagerBase()
123-
PathManager.register_handler(DropboxHandler())
124-
PathManager.register_handler(LayoutParserHandler())
119+
PathManager.register_handler(LayoutParserDetectron2ModelHandler())

src/layoutparser/models/layoutmodel.py renamed to src/layoutparser/models/detectron2/layoutmodel.py

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,14 @@
1-
from abc import ABC, abstractmethod
2-
import os
3-
import importlib
4-
51
from PIL import Image
62
import numpy as np
73
import torch
84

95
from .catalog import PathManager, LABEL_MAP_CATALOG
10-
from ..elements import *
6+
from ..base_layoutmodel import BaseLayoutModel
7+
from ...elements import Rectangle, TextBlock, Layout
118

129
__all__ = ["Detectron2LayoutModel"]
1310

1411

15-
class BaseLayoutModel(ABC):
16-
@abstractmethod
17-
def detect(self):
18-
pass
19-
20-
# Add lazy loading mechanisms for layout models, refer to
21-
# layoutparser.ocr.BaseOCRAgent
22-
# TODO: Build a metaclass for lazy module loader
23-
@property
24-
@abstractmethod
25-
def DEPENDENCIES(self):
26-
"""DEPENDENCIES lists all necessary dependencies for the class."""
27-
pass
28-
29-
@property
30-
@abstractmethod
31-
def MODULES(self):
32-
"""MODULES instructs how to import these necessary libraries."""
33-
pass
34-
35-
@classmethod
36-
def _import_module(cls):
37-
for m in cls.MODULES:
38-
if importlib.util.find_spec(m["module_path"]):
39-
setattr(
40-
cls, m["import_name"], importlib.import_module(m["module_path"])
41-
)
42-
else:
43-
raise ModuleNotFoundError(
44-
f"\n "
45-
f"\nPlease install the following libraries to support the class {cls.__name__}:"
46-
f"\n pip install {' '.join(cls.DEPENDENCIES)}"
47-
f"\n "
48-
)
49-
50-
def __new__(cls, *args, **kwargs):
51-
52-
cls._import_module()
53-
return super().__new__(cls)
54-
55-
5612
class Detectron2LayoutModel(BaseLayoutModel):
5713
"""Create a Detectron2-based Layout Detection Model
5814
@@ -93,6 +49,7 @@ class Detectron2LayoutModel(BaseLayoutModel):
9349
},
9450
{"import_name": "_config", "module_path": "detectron2.config"},
9551
]
52+
DETECTOR_NAME = "detectron2"
9653

9754
def __init__(
9855
self,
@@ -111,18 +68,47 @@ def __init__(
11168
extra_config.extend(["MODEL.DEVICE", "cpu"])
11269

11370
cfg = self._config.get_cfg()
71+
config_path = self._reconstruct_path_with_detector_name(config_path)
11472
config_path = PathManager.get_local_path(config_path)
11573
cfg.merge_from_file(config_path)
11674
cfg.merge_from_list(extra_config)
11775

11876
if model_path is not None:
77+
model_path = self._reconstruct_path_with_detector_name(model_path)
11978
cfg.MODEL.WEIGHTS = model_path
12079
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12180
self.cfg = cfg
12281

12382
self.label_map = label_map
12483
self._create_model()
12584

85+
def _reconstruct_path_with_detector_name(self, path: str) -> str:
86+
"""This function will add the detector name (detectron2) into the
87+
lp model config path to get the "canonical" model name.
88+
89+
For example, for a given config_path `lp://HJDataset/faster_rcnn_R_50_FPN_3x/config`,
90+
it will transform it into `lp://detectron2/HJDataset/faster_rcnn_R_50_FPN_3x/config`.
91+
However, if the config_path already contains the detector name, we won't change it.
92+
93+
This function is a general step to support multiple backends in the layout-parser
94+
library.
95+
96+
Args:
97+
path (str): The given input path that might or might not contain the detector name.
98+
99+
Returns:
100+
str: a modified path that contains the detector name.
101+
"""
102+
if path.startswith("lp://"): # TODO: Move "lp://" to a constant
103+
model_name = path[len("lp://") :]
104+
model_name_segments = model_name.split("/")
105+
if (
106+
len(model_name_segments) == 3
107+
and "detectron2" not in model_name_segments
108+
):
109+
return "lp://" + self.DETECTOR_NAME + "/" + path[len("lp://") :]
110+
return path
111+
126112
def gather_output(self, outputs):
127113

128114
instance_pred = outputs["instances"].to("cpu")

0 commit comments

Comments
 (0)