Skip to content

Commit 035f66a

Browse files
an1018lolipopshock
andauthored
Add PaddleDetection-based Layout Model (#54)
* add paddle model * Better model downloading logic * Use layout parser PathManager * simplify the layoutmodel in paddledetection * remove the empty preprocess.py file * include paddle models in dev-requirements Co-authored-by: Shannon Shen <[email protected]>
1 parent 2d35ab6 commit 035f66a

File tree

7 files changed

+526
-4
lines changed

7 files changed

+526
-4
lines changed

dev-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ sphinx_rtd_theme
1010
google-cloud-vision==1
1111
pytesseract
1212
pycocotools
13-
git+https://github.com/facebookresearch/[email protected]#egg=detectron2
13+
git+https://github.com/facebookresearch/[email protected]#egg=detectron2
14+
paddlepaddle

src/layoutparser/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
)
1616

1717
from .models import (
18-
Detectron2LayoutModel
18+
Detectron2LayoutModel,
19+
PaddleDetectionLayoutModel
1920
)
2021

2122
from .io import (
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .detectron2.layoutmodel import Detectron2LayoutModel
1+
from .detectron2.layoutmodel import Detectron2LayoutModel
2+
from .paddledetection.layoutmodel import PaddleDetectionLayoutModel
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import catalog as _UNUSED
2+
from .layoutmodel import PaddleDetectionLayoutModel
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import os
2+
import logging
3+
from typing import Any, Optional
4+
from urllib.parse import urlparse
5+
import tarfile
6+
import uuid
7+
8+
from iopath.common.file_io import PathHandler
9+
from iopath.common.file_io import HTTPURLHandler
10+
from iopath.common.file_io import get_cache_dir, file_lock
11+
from iopath.common.download import download
12+
13+
from ..base_catalog import PathManager
14+
15+
# Download URLs of the packaged PaddleDetection layout models, keyed first by
# dataset name and then by model name. Every entry points at a tar archive.
CONFIG_CATALOG = dict(
    PubLayNet={
        "ppyolov2_r50vd_dcn_365e_publaynet": "https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar",
    },
    TableBank={
        "ppyolov2_r50vd_dcn_365e_tableBank_word": "https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar",
        "ppyolov2_r50vd_dcn_365e_tableBank_latex": "https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar",
    },
)
24+
25+
# fmt: off
26+
# Integer class id -> label name, one mapping per dataset in CONFIG_CATALOG.
LABEL_MAP_CATALOG = dict(
    PubLayNet={0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"},
    TableBank={0: "Table"},
)
37+
# fmt: on
38+
39+
40+
# Paddle models are distributed as tar archives. Each model's archive is
# expected to contain the following files:
_TAR_FILE_NAME_LIST = [
    "inference.pdiparams",
    "inference.pdiparams.info",
    "inference.pdmodel",
]
47+
48+
49+
def _get_untar_directory(tar_file: str) -> str:
    """Return the folder a model tar file is extracted into.

    The folder lives next to the tar file and carries the tar file's name
    with its last extension removed, e.g. ``/cache/model.tar`` maps to
    ``/cache/model``. Nothing is created on disk; this only computes the path.
    """
    parent_dir, tar_name = os.path.split(tar_file)
    stem, _extension = os.path.splitext(tar_name)
    return os.path.join(parent_dir, stem)
56+
57+
58+
def _untar_model_weights(model_tar):
    """Extract the model files of a Paddle model tar file next to it.

    Only archive members whose name contains one of the entries of
    ``_TAR_FILE_NAME_LIST`` are extracted. They are written flat into the
    target folder under the matching entry's name, whatever their location
    inside the archive. Extraction is skipped when every expected file is
    already present in the target folder.

    Args:
        model_tar: Path of the downloaded model tar file.

    Returns:
        Path of the folder that holds the extracted model files.
    """
    import shutil  # local import: only needed when an extraction happens

    model_dir = _get_untar_directory(model_tar)

    # Require *all* expected files, matching is_cached_folder_exists_and_valid;
    # otherwise a partially extracted folder would never be repaired.
    if all(
        os.path.exists(os.path.join(model_dir, name))
        for name in _TAR_FILE_NAME_LIST
    ):
        return model_dir

    # the path to save the decompressed file
    os.makedirs(model_dir, exist_ok=True)
    with tarfile.open(model_tar, "r") as tarobj:
        for member in tarobj.getmembers():
            matches = [name for name in _TAR_FILE_NAME_LIST if name in member.name]
            if not matches:
                continue
            # "inference.pdiparams" is a substring of
            # "inference.pdiparams.info"; the longest match is the right one
            # regardless of the order of _TAR_FILE_NAME_LIST.
            filename = max(matches, key=len)

            source = tarobj.extractfile(member)
            if source is None:
                # Directories and special members carry no file data.
                continue
            with source, open(os.path.join(model_dir, filename), "wb") as model_file:
                # Stream the content instead of loading a whole weight file
                # into memory.
                shutil.copyfileobj(source, model_file)
    return model_dir
80+
81+
82+
def is_cached_folder_exists_and_valid(cached):
    """Tell whether the model archive at *cached* is already extracted.

    Args:
        cached: Path of the (possibly no longer existing) model tar file.

    Returns:
        True when the extraction folder derived from *cached* exists and
        contains every file named in ``_TAR_FILE_NAME_LIST``; False otherwise.
    """
    extracted_dir = _get_untar_directory(cached)
    if not os.path.exists(extracted_dir):
        return False
    return all(
        os.path.exists(os.path.join(extracted_dir, required_file))
        for required_file in _TAR_FILE_NAME_LIST
    )
90+
91+
92+
class PaddleModelURLHandler(HTTPURLHandler):
    """
    Supports download and file check for Baidu Cloud links
    """

    # File names longer than this are shortened before being used on disk.
    MAX_FILENAME_LEN = 250

    def _get_supported_prefixes(self):
        """Return the URL prefixes this handler is responsible for."""
        return ["https://paddle-model-ecology.bj.bcebos.com"]

    def _isfile(self, path):
        """Report a URL as a file once it has an entry in ``self.cache_map``."""
        return path in self.cache_map

    def _get_local_path(
        self,
        path: str,
        force: bool = False,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """
        As paddle model stores all files in tar files, we need to extract them
        and get the newly extracted folder path. This function rewrites the base
        function to support the following situations:

        1. If the tar file is not downloaded, it will download the tar file,
            extract it to the target folder, delete the downloaded tar file,
            and return the folder path.
        2. If the extracted target folder is present, and all the necessary model
            files are present (specified in _TAR_FILE_NAME_LIST), it will
            return the folder path.
        3. If the tar file is downloaded, but the extracted target folder is not
            present (or it doesn't contain the necessary files in _TAR_FILE_NAME_LIST),
            it will extract the tar file to the target folder, delete the tar file,
            and return the folder path.

        A URL that does not end with ``.tar`` is downloaded and the path of the
        downloaded file itself is returned.

        Args:
            path: URL of the resource to resolve.
            force: Resolve the URL again even if it is already recorded in
                ``self.cache_map``.
            cache_dir: Cache root handed to ``get_cache_dir``.
            **kwargs: Passed to ``self._check_kwargs`` only.

        Returns:
            The local path recorded for ``path`` in ``self.cache_map``.
        """
        self._check_kwargs(kwargs)
        if (
            force
            or path not in self.cache_map
            or not os.path.exists(self.cache_map[path])
        ):
            logger = logging.getLogger(__name__)
            parsed_url = urlparse(path)
            dirname = os.path.join(
                get_cache_dir(cache_dir), os.path.dirname(parsed_url.path.lstrip("/"))
            )
            filename = path.split("/")[-1]
            if len(filename) > self.MAX_FILENAME_LEN:
                filename = filename[:100] + "_" + uuid.uuid4().hex

            cached = os.path.join(dirname, filename)

            if is_cached_folder_exists_and_valid(cached):
                # When the cached folder exists and valid, we don't need to redownload
                # the tar file.
                self.cache_map[path] = _get_untar_directory(cached)

            else:
                with file_lock(cached):
                    if not os.path.isfile(cached):
                        logger.info("Downloading {} ...".format(path))
                        cached = download(path, dirname, filename=filename)

                if path.endswith(".tar"):
                    model_dir = _untar_model_weights(cached)
                    try:
                        os.remove(cached)  # remove the redundant tar file
                        # TODO: remove the .lock file .
                    except OSError:
                        # Best effort only: a leftover tar file wastes disk
                        # space but does not affect the extracted model.
                        logger.warning(
                            f"Not able to remove the cached tar file {cached}"
                        )
                else:
                    # Nothing to extract; without this branch `model_dir`
                    # would be unbound below for non-tar URLs.
                    model_dir = cached

                logger.info("URL {} cached in {}".format(path, model_dir))
                self.cache_map[path] = model_dir

        return self.cache_map[path]
171+
172+
173+
class LayoutParserPaddleModelHandler(PathHandler):
    """
    Resolve anything that's in LayoutParser model zoo.

    Handles paths of the form
    ``lp://paddledetection/<dataset>/<model name>/<data type>``, where the
    only data type understood is ``config``.
    """

    PREFIX = "lp://paddledetection/"

    def _get_supported_prefixes(self):
        """Return the single path prefix served by this handler."""
        return [self.PREFIX]

    def _get_local_path(self, path, **kwargs):
        """Map a model zoo path to its download URL and resolve that URL.

        The first path component after the prefix is the dataset, the last
        one is the data type, and everything in between is the model name
        used to look up ``CONFIG_CATALOG``.

        Raises:
            ValueError: If the data type is not ``config``.
        """
        relative_name = path[len(self.PREFIX) :]
        dataset_name, *name_parts, data_type = relative_name.split("/")

        if data_type != "config":
            raise ValueError(f"Unknown data_type {data_type}")

        model_url = CONFIG_CATALOG[dataset_name]["/".join(name_parts)]
        return PathManager.get_local_path(model_url, **kwargs)

    def _open(self, path, mode="r", **kwargs):
        """Resolve *path* locally, then open it through ``PathManager``."""
        local_path = self._get_local_path(path)
        return PathManager.open(local_path, mode, **kwargs)
195+
196+
197+
# Register the handlers on import, in the same order as before: the URL
# handler first, then the "lp://paddledetection/" model zoo handler.
for _handler_cls in (PaddleModelURLHandler, LayoutParserPaddleModelHandler):
    PathManager.register_handler(_handler_cls())
del _handler_cls

0 commit comments

Comments
 (0)