Skip to content

Commit 37d5a10

Browse files
committed
add MINERU_MODELS_DIR、npu
1 parent 88e5d51 commit 37d5a10

File tree

20 files changed

+271
-230
lines changed

20 files changed

+271
-230
lines changed

demo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
# os.environ['MINERU_DEVICE_MODE'] = "cuda"
1010
# # 或指定 GPU 编号,例如使用第二块 GPU(cuda:1)
1111
# os.environ['MINERU_DEVICE_MODE'] = "cuda:1"
12-
12+
# # 模型文件存储目录
13+
# os.environ['MINERU_MODELS_DIR'] = r'D:\CodeProjects\doc\RapidAI\models' #模型文件存储目录,如果不设置会默认下载到rapid_doc项目里面
1314
from loguru import logger
1415

1516
from rapid_doc.cli.common import convert_pdf_bytes_to_bytes_by_pypdfium2, prepare_env, read_fn

demo/demo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
# os.environ['MINERU_DEVICE_MODE'] = "cuda"
88
# # 或指定 GPU 编号,例如使用第二块 GPU(cuda:1)
99
# os.environ['MINERU_DEVICE_MODE'] = "cuda:1"
10+
# # 模型文件存储目录
11+
# os.environ['MINERU_MODELS_DIR'] = r'D:\CodeProjects\doc\RapidAI\models' #模型文件存储目录,如果不设置会默认下载到rapid_doc项目里面
1012
from loguru import logger
1113

1214
from rapid_doc.cli.common import convert_pdf_bytes_to_bytes_by_pypdfium2, prepare_env, read_fn

docker/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ ENV API_PORT=8888
4343
ENV STARTUP_WAIT_TIME=15
4444
ENV LOG_LEVEL=INFO
4545
ENV MINERU_DEVICE_MODE=cpu
46+
ENV MINERU_MODELS_DIR=/app/models
4647
# 下载默认模型文件实现离线部署
4748
RUN python3 download_models.py
4849

docker/DockerfileGPU

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ ENV API_PORT=8888
5151
ENV STARTUP_WAIT_TIME=15
5252
ENV LOG_LEVEL=INFO
5353
ENV MINERU_DEVICE_MODE=cuda:0
54+
ENV MINERU_MODELS_DIR=/app/models
5455
# 下载默认模型文件实现离线部署
5556
RUN python3 download_models.py
5657

docker/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,4 @@ curl -X POST "http://localhost:8888/parse" \
7272
|--------|--------|------|
7373
| `STARTUP_WAIT_TIME` | `15` | 启动等待时间(秒) |
7474
| `LOG_LEVEL` | `INFO` | 日志级别 |
75+
| `MINERU_MODELS_DIR` | `/app/models` | 模型文件存储目录 |

docker/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _convert_value_to_enum(config):
8888
return config
8989
from rapidocr import EngineType as OCREngineType, OCRVersion, ModelType as OCRModelType, LangDet, LangRec
9090
from rapid_doc.model.layout.rapid_layout_self import ModelType as LayoutModelType
91-
from rapid_doc.model.formula.rapid_formula_self import ModelType as FormulaModelType
91+
from rapid_doc.model.formula.rapid_formula_self import ModelType as FormulaModelType, EngineType as FormulaEngineType
9292
from rapid_doc.model.table.rapid_table_self import ModelType as TableModelType
9393

9494
# 可识别的枚举类映射表(可扩展)
@@ -100,6 +100,7 @@ def _convert_value_to_enum(config):
100100
"LangRec": LangRec,
101101
"LayoutModelType": LayoutModelType,
102102
"FormulaModelType": FormulaModelType,
103+
"FormulaEngineType": FormulaEngineType,
103104
"TableModelType": TableModelType,
104105
}
105106

docker/download_models.py

Lines changed: 9 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1,159 +1,14 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Model download script for Docker build
4+
Downloads pipeline models for offline deployment
5+
"""
6+
import os
17
import sys
2-
from pathlib import Path
3-
from typing import Union
4-
from loguru import logger
5-
from omegaconf import DictConfig, OmegaConf
6-
from download_file import DownloadFileInput, DownloadFile
7-
8-
def read_yaml(file_path: Union[str, Path]) -> DictConfig:
9-
return OmegaConf.load(file_path)
10-
11-
def default_download(models_pkg, configs_pkg):
12-
# 获取 models 模块的目录
13-
model_dir = Path(models_pkg.__path__[0])
14-
# 获取 configs 模块所在目录
15-
configs_dir = Path(configs_pkg.__file__).parent
16-
# 拼接 default_models.yaml 文件路径
17-
default_models_yaml = configs_dir / "default_models.yaml"
18-
model_map = read_yaml(default_models_yaml)
19-
20-
for model_name, model_info in model_map.items():
21-
if model_name in ['unitable']:
22-
# multi_models
23-
model_root_dir = model_info["model_dir_or_path"]
24-
save_model_dir = model_dir / Path(model_root_dir).name
25-
for file_name, sha256 in model_info["SHA256"].items():
26-
save_path = save_model_dir / file_name
27-
28-
download_params = DownloadFileInput(
29-
file_url=f"{model_root_dir}/{file_name}",
30-
sha256=sha256,
31-
save_path=save_path,
32-
)
33-
DownloadFile.run(download_params)
34-
elif model_name in ['onnxruntime', 'torch', 'openvino']:
35-
for name, item_model_info in model_info.items():
36-
model_dir_or_path = item_model_info["model_dir_or_path"]
37-
sha256 = item_model_info["SHA256"]
38-
save_model_path = (
39-
model_dir / Path(model_dir_or_path).name
40-
)
41-
download_params = DownloadFileInput(
42-
file_url=model_dir_or_path,
43-
sha256=sha256,
44-
save_path=save_model_path,
45-
)
46-
DownloadFile.run(download_params)
47-
48-
# 如果有字典文件,下载字典
49-
dict_download_url = item_model_info.get("dict_url")
50-
if dict_download_url:
51-
dict_path = (model_dir / Path(dict_download_url).name)
52-
if dict_download_url and not Path(dict_path).exists():
53-
DownloadFile.run(
54-
DownloadFileInput(
55-
file_url=dict_download_url,
56-
sha256=None,
57-
save_path=dict_path,
58-
)
59-
)
60-
else:
61-
model_dir_or_path = model_info["model_dir_or_path"]
62-
sha256 = model_info["SHA256"]
63-
64-
save_model_path = (
65-
model_dir / Path(model_dir_or_path).name
66-
)
67-
download_params = DownloadFileInput(
68-
file_url=model_dir_or_path,
69-
sha256=sha256,
70-
save_path=save_model_path,
71-
)
72-
DownloadFile.run(download_params)
73-
74-
def ocr_download(models_pkg, configs_pkg):
75-
# 获取 models 模块的目录
76-
model_dir = Path(models_pkg.__path__[0])
77-
# 获取 configs 模块所在目录
78-
configs_dir = Path(configs_pkg.__file__).parent
79-
# 拼接 default_models.yaml 文件路径
80-
default_models_yaml = configs_dir / "default_models.yaml"
81-
model_map = read_yaml(default_models_yaml)
82-
83-
for engin_name, engin_info in model_map.items(): # model_info为onnxruntime层级
84-
if engin_name in ['openvino', 'torch', 'fonts']:
85-
if engin_name == 'fonts':
86-
for lang, font_info in engin_info.items():
87-
font_path = font_info["path"]
88-
font_sha256 = font_info["SHA256"]
89-
90-
font_save_model_path = (
91-
model_dir / Path(font_path).name
92-
)
93-
download_params = DownloadFileInput(
94-
file_url=font_path,
95-
sha256=font_sha256,
96-
save_path=font_save_model_path,
97-
)
98-
DownloadFile.run(download_params)
99-
else:
100-
for version, ocr_info in engin_info.items(): # ocr_info为PP-OCRv4层级
101-
for det, det_info in ocr_info.items(): # info为det层级
102-
for model_name, model_info in det_info.items():
103-
# 如果有字典文件,下载字典
104-
dict_download_url = model_info.get("dict_url")
105-
if dict_download_url:
106-
dict_path = (model_dir / Path(dict_download_url).name)
107-
if dict_download_url and not Path(dict_path).exists():
108-
DownloadFile.run(
109-
DownloadFileInput(
110-
file_url=dict_download_url,
111-
sha256=None,
112-
save_path=dict_path,
113-
)
114-
)
115-
# 下载模型
116-
model_path = model_dir / Path(model_info["model_dir"]).name
117-
download_params = DownloadFileInput(
118-
file_url=model_info["model_dir"],
119-
sha256=model_info["SHA256"],
120-
save_path=model_path,
121-
)
122-
DownloadFile.run(download_params)
123-
124-
def download_pipeline_models():
125-
"""下载Pipeline模型"""
126-
try:
127-
# # 下载版面识别模型
128-
# logger.info('开始下载版面识别模型...')
129-
# import rapid_doc.model.layout.rapid_layout_self.models as layout_models_pkg
130-
# import rapid_doc.model.layout.rapid_layout_self.configs as layout_configs_pkg
131-
# default_download(layout_models_pkg, layout_configs_pkg)
132-
#
133-
# # 下载公式识别模型
134-
# logger.info('开始下载公式识别模型...')
135-
# import rapid_doc.model.formula.rapid_formula_self.models as formula_models_pkg
136-
# import rapid_doc.model.formula.rapid_formula_self.configs as formula_configs_pkg
137-
# default_download(formula_models_pkg, formula_configs_pkg)
138-
139-
# 下载表格识别模型
140-
logger.info('开始下载表格识别模型...')
141-
import rapid_doc.model.table.rapid_table_self.models as table_models_pkg
142-
import rapid_doc.model.table.rapid_table_self as table_configs_pkg
143-
default_download(table_models_pkg, table_configs_pkg)
144-
145-
# # 下载OCR模型
146-
# logger.info('开始下载OCR模型...')
147-
# import rapidocr.models as ocr_models_pkg
148-
# import rapidocr as ocr_configs_pkg
149-
# ocr_download(ocr_models_pkg, ocr_configs_pkg)
150-
# logger.info('所有模型下载完成: success download')
151-
return True
152-
except Exception as e:
153-
logger.error(f'模型下载失败: {e}')
154-
return True
155-
8+
from rapid_doc.utils.models_download_utils import download_pipeline_models
1569

15710
if __name__ == '__main__':
11+
os.environ['MINERU_MODELS_DIR'] = r'D:\CodeProjects\doc\RapidAI\models' #模型文件存储目录
12+
os.environ["MINERU_DEVICE_MODE"] = "cpu" # cpu、cuda、npu、all(all只是用来下载)
15813
success = download_pipeline_models()
15914
sys.exit(0 if success else 1)

docs/analyze_param.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ def doc_analyze(
1818
)
1919
```
2020
在mineru参数基础上新增了layout_config、ocr_config、formula_config、table_config、checkbox_config参数
21+
22+
#### 0、环境变量
23+
```bash
24+
# 用于指定推理设备。支持cpu/cuda/cuda:0/npu等设备类型
25+
os.environ['MINERU_DEVICE_MODE'] = "cpu"
26+
27+
# 模型文件存储目录。如果不设置会默认下载到rapid_doc项目里面
28+
os.environ['MINERU_MODELS_DIR'] = r'D:\CodeProjects\doc\RapidAI\models'
29+
```
2130
#### 1、使用gpu推理
2231
```bash
2332
# 在安装完 rapid_doc 之后,卸载 cpu 版的 onnxruntime

magic.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
"right": "$"
1414
}
1515
},
16-
"config_version": "1.3.0"
16+
"config_version": "1.3.1"
1717
}

rapid_doc/model/formula/rapid_formula_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
import os
12
import time
3+
from pathlib import Path
24

35
from rapid_doc.model.formula.rapid_formula_self import ModelType, RapidFormula, RapidFormulaInput, EngineType
46
from rapid_doc.utils.config_reader import get_device
7+
from rapid_doc.model.formula.rapid_formula_self.model_handler import ModelProcessor
8+
models_dir = os.getenv('MINERU_MODELS_DIR', None)
9+
if models_dir:
10+
# 从指定的文件夹内寻找模型文件
11+
ModelProcessor.DEFAULT_MODEL_DIR = Path(models_dir)
512

613
class RapidFormulaModel(object):
714
def __init__(self, formula_config=None):
@@ -14,6 +21,12 @@ def __init__(self, formula_config=None):
1421
cfg.engine_cfg = engine_cfg
1522
cfg.model_type = ModelType.PP_FORMULANET_PLUS_M
1623
cfg.engine_type = EngineType.TORCH
24+
elif device.startswith('npu'):
25+
device_id = int(device.split(':')[1]) if ':' in device else 0 # npu 编号
26+
engine_cfg = {'use_npu': True, "npu_id": device_id}
27+
cfg.engine_cfg = engine_cfg
28+
cfg.model_type = ModelType.PP_FORMULANET_PLUS_M
29+
cfg.engine_type = EngineType.TORCH
1730
# 如果传入了 formula_config,则用传入配置覆盖默认配置
1831
if formula_config is not None:
1932
# 遍历字典,把传入配置设置到 default_cfg 对象中

0 commit comments

Comments
 (0)