|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Model download script for Docker build |
| 4 | +Downloads pipeline models for offline deployment |
| 5 | +""" |
| 6 | +import os |
1 | 7 | import sys |
2 | | -from pathlib import Path |
3 | | -from typing import Union |
4 | | -from loguru import logger |
5 | | -from omegaconf import DictConfig, OmegaConf |
6 | | -from download_file import DownloadFileInput, DownloadFile |
7 | | - |
8 | | -def read_yaml(file_path: Union[str, Path]) -> DictConfig: |
9 | | - return OmegaConf.load(file_path) |
10 | | - |
11 | | -def default_download(models_pkg, configs_pkg): |
12 | | - # 获取 models 模块的目录 |
13 | | - model_dir = Path(models_pkg.__path__[0]) |
14 | | - # 获取 configs 模块所在目录 |
15 | | - configs_dir = Path(configs_pkg.__file__).parent |
16 | | - # 拼接 default_models.yaml 文件路径 |
17 | | - default_models_yaml = configs_dir / "default_models.yaml" |
18 | | - model_map = read_yaml(default_models_yaml) |
19 | | - |
20 | | - for model_name, model_info in model_map.items(): |
21 | | - if model_name in ['unitable']: |
22 | | - # multi_models |
23 | | - model_root_dir = model_info["model_dir_or_path"] |
24 | | - save_model_dir = model_dir / Path(model_root_dir).name |
25 | | - for file_name, sha256 in model_info["SHA256"].items(): |
26 | | - save_path = save_model_dir / file_name |
27 | | - |
28 | | - download_params = DownloadFileInput( |
29 | | - file_url=f"{model_root_dir}/{file_name}", |
30 | | - sha256=sha256, |
31 | | - save_path=save_path, |
32 | | - ) |
33 | | - DownloadFile.run(download_params) |
34 | | - elif model_name in ['onnxruntime', 'torch', 'openvino']: |
35 | | - for name, item_model_info in model_info.items(): |
36 | | - model_dir_or_path = item_model_info["model_dir_or_path"] |
37 | | - sha256 = item_model_info["SHA256"] |
38 | | - save_model_path = ( |
39 | | - model_dir / Path(model_dir_or_path).name |
40 | | - ) |
41 | | - download_params = DownloadFileInput( |
42 | | - file_url=model_dir_or_path, |
43 | | - sha256=sha256, |
44 | | - save_path=save_model_path, |
45 | | - ) |
46 | | - DownloadFile.run(download_params) |
47 | | - |
48 | | - # 如果有字典文件,下载字典 |
49 | | - dict_download_url = item_model_info.get("dict_url") |
50 | | - if dict_download_url: |
51 | | - dict_path = (model_dir / Path(dict_download_url).name) |
52 | | - if dict_download_url and not Path(dict_path).exists(): |
53 | | - DownloadFile.run( |
54 | | - DownloadFileInput( |
55 | | - file_url=dict_download_url, |
56 | | - sha256=None, |
57 | | - save_path=dict_path, |
58 | | - ) |
59 | | - ) |
60 | | - else: |
61 | | - model_dir_or_path = model_info["model_dir_or_path"] |
62 | | - sha256 = model_info["SHA256"] |
63 | | - |
64 | | - save_model_path = ( |
65 | | - model_dir / Path(model_dir_or_path).name |
66 | | - ) |
67 | | - download_params = DownloadFileInput( |
68 | | - file_url=model_dir_or_path, |
69 | | - sha256=sha256, |
70 | | - save_path=save_model_path, |
71 | | - ) |
72 | | - DownloadFile.run(download_params) |
73 | | - |
74 | | -def ocr_download(models_pkg, configs_pkg): |
75 | | - # 获取 models 模块的目录 |
76 | | - model_dir = Path(models_pkg.__path__[0]) |
77 | | - # 获取 configs 模块所在目录 |
78 | | - configs_dir = Path(configs_pkg.__file__).parent |
79 | | - # 拼接 default_models.yaml 文件路径 |
80 | | - default_models_yaml = configs_dir / "default_models.yaml" |
81 | | - model_map = read_yaml(default_models_yaml) |
82 | | - |
83 | | - for engin_name, engin_info in model_map.items(): # model_info为onnxruntime层级 |
84 | | - if engin_name in ['openvino', 'torch', 'fonts']: |
85 | | - if engin_name == 'fonts': |
86 | | - for lang, font_info in engin_info.items(): |
87 | | - font_path = font_info["path"] |
88 | | - font_sha256 = font_info["SHA256"] |
89 | | - |
90 | | - font_save_model_path = ( |
91 | | - model_dir / Path(font_path).name |
92 | | - ) |
93 | | - download_params = DownloadFileInput( |
94 | | - file_url=font_path, |
95 | | - sha256=font_sha256, |
96 | | - save_path=font_save_model_path, |
97 | | - ) |
98 | | - DownloadFile.run(download_params) |
99 | | - else: |
100 | | - for version, ocr_info in engin_info.items(): # ocr_info为PP-OCRv4层级 |
101 | | - for det, det_info in ocr_info.items(): # info为det层级 |
102 | | - for model_name, model_info in det_info.items(): |
103 | | - # 如果有字典文件,下载字典 |
104 | | - dict_download_url = model_info.get("dict_url") |
105 | | - if dict_download_url: |
106 | | - dict_path = (model_dir / Path(dict_download_url).name) |
107 | | - if dict_download_url and not Path(dict_path).exists(): |
108 | | - DownloadFile.run( |
109 | | - DownloadFileInput( |
110 | | - file_url=dict_download_url, |
111 | | - sha256=None, |
112 | | - save_path=dict_path, |
113 | | - ) |
114 | | - ) |
115 | | - # 下载模型 |
116 | | - model_path = model_dir / Path(model_info["model_dir"]).name |
117 | | - download_params = DownloadFileInput( |
118 | | - file_url=model_info["model_dir"], |
119 | | - sha256=model_info["SHA256"], |
120 | | - save_path=model_path, |
121 | | - ) |
122 | | - DownloadFile.run(download_params) |
123 | | - |
124 | | -def download_pipeline_models(): |
125 | | - """下载Pipeline模型""" |
126 | | - try: |
127 | | - # # 下载版面识别模型 |
128 | | - # logger.info('开始下载版面识别模型...') |
129 | | - # import rapid_doc.model.layout.rapid_layout_self.models as layout_models_pkg |
130 | | - # import rapid_doc.model.layout.rapid_layout_self.configs as layout_configs_pkg |
131 | | - # default_download(layout_models_pkg, layout_configs_pkg) |
132 | | - # |
133 | | - # # 下载公式识别模型 |
134 | | - # logger.info('开始下载公式识别模型...') |
135 | | - # import rapid_doc.model.formula.rapid_formula_self.models as formula_models_pkg |
136 | | - # import rapid_doc.model.formula.rapid_formula_self.configs as formula_configs_pkg |
137 | | - # default_download(formula_models_pkg, formula_configs_pkg) |
138 | | - |
139 | | - # 下载表格识别模型 |
140 | | - logger.info('开始下载表格识别模型...') |
141 | | - import rapid_doc.model.table.rapid_table_self.models as table_models_pkg |
142 | | - import rapid_doc.model.table.rapid_table_self as table_configs_pkg |
143 | | - default_download(table_models_pkg, table_configs_pkg) |
144 | | - |
145 | | - # # 下载OCR模型 |
146 | | - # logger.info('开始下载OCR模型...') |
147 | | - # import rapidocr.models as ocr_models_pkg |
148 | | - # import rapidocr as ocr_configs_pkg |
149 | | - # ocr_download(ocr_models_pkg, ocr_configs_pkg) |
150 | | - # logger.info('所有模型下载完成: success download') |
151 | | - return True |
152 | | - except Exception as e: |
153 | | - logger.error(f'模型下载失败: {e}') |
154 | | - return True |
155 | | - |
| 8 | +from rapid_doc.utils.models_download_utils import download_pipeline_models |
156 | 9 |
|
157 | 10 | if __name__ == '__main__': |
| 11 | + os.environ['MINERU_MODELS_DIR'] = r'D:\CodeProjects\doc\RapidAI\models' #模型文件存储目录 |
| 12 | + os.environ["MINERU_DEVICE_MODE"] = "cpu" # cpu、cuda、npu、all(all只是用来下载) |
158 | 13 | success = download_pipeline_models() |
159 | 14 | sys.exit(0 if success else 1) |
0 commit comments