import os
import logging
from typing import Any, Optional
from urllib.parse import urlparse
import tarfile
import uuid

from iopath.common.file_io import PathHandler
from iopath.common.file_io import HTTPURLHandler
from iopath.common.file_io import get_cache_dir, file_lock
from iopath.common.download import download

from ..base_catalog import PathManager

CONFIG_CATALOG = {
    "PubLayNet": {
        "ppyolov2_r50vd_dcn_365e_publaynet": "https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar",
    },
    "TableBank": {
        "ppyolov2_r50vd_dcn_365e_tableBank_word": "https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar",
        "ppyolov2_r50vd_dcn_365e_tableBank_latex": "https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar",
    },
}
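
# For example, CONFIG_CATALOG["PubLayNet"]["ppyolov2_r50vd_dcn_365e_publaynet"]
# yields a tar URL that the PaddleModelURLHandler defined below can download
# and extract.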

# fmt: off
LABEL_MAP_CATALOG = {
    "PubLayNet": {
        0: "Text",
        1: "Title",
        2: "List",
        3: "Table",
        4: "Figure",
    },
    "TableBank": {
        0: "Table",
    },
}
# fmt: on

# Paddle packages model weights in tar files; each model's tar file should
# contain the following files:
_TAR_FILE_NAME_LIST = [
    "inference.pdiparams",
    "inference.pdiparams.info",
    "inference.pdmodel",
]


def _get_untar_directory(tar_file: str) -> str:
    """Return the folder that a model tar file should be extracted to,
    i.e., the tar file's own path with the extension dropped."""
    base_path = os.path.dirname(tar_file)
    file_name = os.path.splitext(os.path.basename(tar_file))[0]
    target_folder = os.path.join(base_path, file_name)

    return target_folder
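
# For example (hypothetical paths): _get_untar_directory("/cache/model.tar")
# returns "/cache/model", the folder that _untar_model_weights populates below.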


def _untar_model_weights(model_tar):
    """Extract the required model files from the given tar file and return
    the directory they are extracted to."""

    model_dir = _get_untar_directory(model_tar)

    if not os.path.exists(
        os.path.join(model_dir, _TAR_FILE_NAME_LIST[0])
    ) or not os.path.exists(os.path.join(model_dir, _TAR_FILE_NAME_LIST[2])):
        # The path to save the decompressed files
        os.makedirs(model_dir, exist_ok=True)
        with tarfile.open(model_tar, "r") as tarobj:
            for member in tarobj.getmembers():
                filename = None
                for tar_file_name in _TAR_FILE_NAME_LIST:
                    if tar_file_name in member.name:
                        # No break: the last (most specific) match wins, so
                        # "inference.pdiparams.info" is preferred over its
                        # prefix "inference.pdiparams".
                        filename = tar_file_name
                if filename is None:
                    continue
                file = tarobj.extractfile(member)
                with open(os.path.join(model_dir, filename), "wb") as model_file:
                    model_file.write(file.read())
    return model_dir


def is_cached_folder_exists_and_valid(cached):
    """Check whether the extraction folder for the given tar file already
    exists and contains all the files in _TAR_FILE_NAME_LIST."""
    possible_extracted_model_folder = _get_untar_directory(cached)
    if not os.path.exists(possible_extracted_model_folder):
        return False
    for tar_file in _TAR_FILE_NAME_LIST:
        if not os.path.exists(os.path.join(possible_extracted_model_folder, tar_file)):
            return False
    return True

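
# The expected on-disk layout after a successful download and extraction
# (illustrative; the actual cache root comes from get_cache_dir):
#
#   <cache_dir>/model/layout-parser/
#       ppyolov2_r50vd_dcn_365e_publaynet.tar   # removed after extraction
#       ppyolov2_r50vd_dcn_365e_publaynet/
#           inference.pdiparams
#           inference.pdiparams.info
#           inference.pdmodel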


class PaddleModelURLHandler(HTTPURLHandler):
    """
    Supports downloading and file checks for models hosted on Baidu Cloud
    (bcebos) links.
    """

    MAX_FILENAME_LEN = 250

    def _get_supported_prefixes(self):
        return ["https://paddle-model-ecology.bj.bcebos.com"]

    def _isfile(self, path):
        return path in self.cache_map

    def _get_local_path(
        self,
        path: str,
        force: bool = False,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """
        As Paddle stores all model files in tar archives, we need to extract
        them and return the path of the newly extracted folder. This method
        overrides the base implementation to support the following situations:

        1. If the tar file is not downloaded, download it, extract it to
           the target folder, delete the downloaded tar file, and return
           the folder path.
        2. If the extracted target folder is present and all the necessary
           model files (specified in _TAR_FILE_NAME_LIST) are present,
           return the folder path.
        3. If the tar file is downloaded but the extracted target folder
           is not present (or it does not contain all the files in
           _TAR_FILE_NAME_LIST), extract the tar file to the target folder,
           delete the tar file, and return the folder path.
        """
        self._check_kwargs(kwargs)
        if (
            force
            or path not in self.cache_map
            or not os.path.exists(self.cache_map[path])
        ):
            logger = logging.getLogger(__name__)
            parsed_url = urlparse(path)
            dirname = os.path.join(
                get_cache_dir(cache_dir), os.path.dirname(parsed_url.path.lstrip("/"))
            )
            filename = path.split("/")[-1]
            if len(filename) > self.MAX_FILENAME_LEN:
                filename = filename[:100] + "_" + uuid.uuid4().hex

            cached = os.path.join(dirname, filename)

            if is_cached_folder_exists_and_valid(cached):
                # When the cached folder exists and is valid, there is no
                # need to re-download the tar file.
                self.cache_map[path] = _get_untar_directory(cached)

            else:
                with file_lock(cached):
                    if not os.path.isfile(cached):
                        logger.info("Downloading {} ...".format(path))
                        cached = download(path, dirname, filename=filename)

                if path.endswith(".tar"):
                    model_dir = _untar_model_weights(cached)
                    try:
                        os.remove(cached)  # remove the redundant tar file
                        # TODO: remove the .lock file.
                    except OSError:
                        logger.warning(
                            f"Not able to remove the cached tar file {cached}"
                        )
                else:
                    # Fall back to the raw download when the file is not a
                    # tar archive (avoids an undefined `model_dir` below).
                    model_dir = cached

                logger.info("URL {} cached in {}".format(path, model_dir))
                self.cache_map[path] = model_dir

        return self.cache_map[path]


class LayoutParserPaddleModelHandler(PathHandler):
    """
    Resolve anything that's in the LayoutParser model zoo.
    """

    PREFIX = "lp://paddledetection/"

    def _get_supported_prefixes(self):
        return [self.PREFIX]

    def _get_local_path(self, path, **kwargs):
        model_name = path[len(self.PREFIX) :]
        dataset_name, *model_name_parts, data_type = model_name.split("/")

        if data_type == "config":
            model_url = CONFIG_CATALOG[dataset_name]["/".join(model_name_parts)]
        else:
            raise ValueError(f"Unknown data_type {data_type}")
        return PathManager.get_local_path(model_url, **kwargs)

    def _open(self, path, mode="r", **kwargs):
        return PathManager.open(self._get_local_path(path), mode, **kwargs)


PathManager.register_handler(PaddleModelURLHandler())
PathManager.register_handler(LayoutParserPaddleModelHandler())
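
# A minimal usage sketch (illustrative, not part of the original module):
# with both handlers registered, PathManager resolves an lp:// model-zoo
# path to the local folder holding the extracted Paddle inference files.
# The first call needs network access to download the tar file.
if __name__ == "__main__":
    local_dir = PathManager.get_local_path(
        "lp://paddledetection/PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config"
    )
    # local_dir ends with ".../ppyolov2_r50vd_dcn_365e_publaynet" and contains
    # the files listed in _TAR_FILE_NAME_LIST.
    print(local_dir)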