Skip to content

Commit ff20502

Browse files
committed
chore: refacter torch infersession
1 parent eb6c685 commit ff20502

File tree

3 files changed

+67
-29
lines changed

3 files changed

+67
-29
lines changed

python/demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@
1010
print(result)
1111

1212
result.vis("vis_result.jpg")
13-
print(result.to_markdown())
13+
print(result)

python/rapidocr/inference_engine/torch.py

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ..networks.architectures.base_model import BaseModel
1111
from ..utils.download_file import DownloadFile, DownloadFileInput
1212
from ..utils.log import logger
13+
from ..utils.utils import mkdir
1314
from .base import FileInfo, InferSession
1415

1516
root_dir = Path(__file__).resolve().parent.parent
@@ -18,8 +19,16 @@
1819

1920
class TorchInferSession(InferSession):
2021
def __init__(self, cfg) -> None:
21-
self.logger = logger
22+
model_path = self._init_model_path(cfg)
23+
arch_config = self._load_arch_config(model_path)
2224

25+
self.predictor = self._build_and_load_model(arch_config, model_path)
26+
27+
self._setup_device(cfg)
28+
29+
self.predictor.eval()
30+
31+
def _init_model_path(self, cfg) -> Path:
2332
model_path = cfg.get("model_path", None)
2433
if model_path is None:
2534
model_info = self.get_model_url(
@@ -38,44 +47,69 @@ def __init__(self, cfg) -> None:
3847
file_url=default_model_url,
3948
sha256=model_info["SHA256"],
4049
save_path=model_path,
41-
logger=self.logger,
50+
logger=logger,
4251
)
4352
)
4453

45-
self.logger.info(f"Using {model_path}")
46-
model_path = Path(model_path)
54+
logger.info(f"Using {model_path}")
4755
self._verify_model(model_path)
56+
return Path(model_path)
4857

58+
def _load_arch_config(self, model_path: Path):
4959
all_arch_config = OmegaConf.load(DEFAULT_CFG_PATH)
60+
5061
file_name = model_path.stem
5162
if file_name not in all_arch_config:
5263
raise ValueError(f"architecture {file_name} is not in arch_config.yaml")
5364

54-
arch_config = all_arch_config.get(file_name)
55-
self.predictor = BaseModel(arch_config)
56-
self.predictor.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=False))
57-
self.predictor.eval()
58-
self.use_gpu = False
59-
self.use_npu = False
65+
return all_arch_config.get(file_name)
66+
67+
def _build_and_load_model(self, arch_config, model_path: Path):
68+
model = BaseModel(arch_config)
69+
state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
70+
model.load_state_dict(state_dict)
71+
return model
72+
73+
def _setup_device(self, cfg):
74+
self.device, self.use_gpu, self.use_npu = self._resolve_device_config(cfg)
75+
76+
if self.use_npu:
77+
self._config_npu()
78+
79+
self._move_model_to_device()
80+
81+
def _resolve_device_config(self, cfg):
6082
if cfg.engine_cfg.use_cuda:
61-
self.device = torch.device(f"cuda:{cfg.engine_cfg.gpu_id}")
62-
self.predictor.to(self.device)
63-
self.use_gpu = True
64-
elif cfg.engine_cfg.use_npu:
65-
try:
66-
import torch_npu
67-
options = {
68-
# 设定算子编译的磁盘缓存模式,非必要每次重新编译
69-
"ACL_OP_COMPILER_CACHE_MODE": "enable",
70-
# 指定缓存目录,确保路径已存在
71-
"ACL_OP_COMPILER_CACHE_DIR": "./kernel_meta",
72-
}
73-
torch_npu.npu.set_option(options)
74-
except ImportError:
75-
self.logger.warning("torch_npu is not installed, options with ACL setting failed.")
76-
self.device = torch.device(f"npu:{cfg.engine_cfg.npu_id}")
77-
self.predictor.to(self.device)
78-
self.use_npu = True
83+
return torch.device(f"cuda:{cfg.engine_cfg.gpu_id}"), True, False
84+
85+
if cfg.engine_cfg.use_npu:
86+
return torch.device(f"npu:{cfg.engine_cfg.npu_id}"), False, True
87+
88+
return torch.device("cpu"), False, False
89+
90+
def _config_npu(self):
91+
try:
92+
import torch_npu
93+
94+
kernel_meta_dir = (root_dir / "kernel_meta").resolve()
95+
mkdir(kernel_meta_dir)
96+
97+
options = {
98+
"ACL_OP_COMPILER_CACHE_MODE": "enable",
99+
"ACL_OP_COMPILER_CACHE_DIR": str(kernel_meta_dir),
100+
}
101+
torch_npu.npu.set_option(options)
102+
except ImportError:
103+
logger.warning(
104+
"torch_npu is not installed, options with ACL setting failed. \n"
105+
"Please refer to https://github.com/Ascend/pytorch to see how to install."
106+
)
107+
108+
self.device = torch.device("cpu")
109+
self.use_npu = False
110+
111+
def _move_model_to_device(self):
112+
self.predictor.to(self.device)
79113

80114
def __call__(self, img: np.ndarray):
81115
with torch.no_grad():

python/rapidocr/utils/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
import numpy as np
1313

1414

15+
def mkdir(dir_path):
16+
Path(dir_path).mkdir(parents=True, exist_ok=True)
17+
18+
1519
def quads_to_rect_bbox(bbox: np.ndarray) -> Tuple[float, float, float, float]:
1620
if bbox.ndim != 3:
1721
raise ValueError("bbox shape must be 3")

0 commit comments

Comments
 (0)