Skip to content

Commit 7d88c84

Browse files
committed
fix(inference_engine): fix model type comparison and improve error handling
Closed #565
1 parent 5cd7de1 commit 7d88c84

File tree

1 file changed

+25
-9
lines changed
  • python/rapidocr/inference_engine

1 file changed

+25
-9
lines changed

python/rapidocr/inference_engine/base.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,25 +95,41 @@ def have_key(self, key: str = "character") -> bool:
9595

9696
@classmethod
9797
def get_model_url(cls, file_info: FileInfo) -> Dict[str, str]:
98+
engine_type = file_info.engine_type.value
99+
ocr_version = file_info.ocr_version.value
100+
task_type = file_info.task_type.value
101+
lang_type = file_info.lang_type.value
102+
model_type = file_info.model_type.value
103+
98104
model_dict = OmegaConf.select(
99-
cls.model_info,
100-
f"{file_info.engine_type.value}.{file_info.ocr_version.value}.{file_info.task_type.value}",
105+
cls.model_info, f"{engine_type}.{ocr_version}.{task_type}"
101106
)
102107

103108
# 优先查找 server 模型
104-
if file_info.model_type == ModelType.SERVER:
109+
if model_type == ModelType.SERVER.value:
105110
for k in model_dict:
106-
if (
107-
k.startswith(file_info.lang_type.value)
108-
and file_info.model_type.value in k
109-
):
111+
if k.startswith(lang_type) and model_type in k:
110112
return model_dict[k]
111113

112114
for k in model_dict:
113-
if k.startswith(file_info.lang_type.value):
115+
if k.startswith(lang_type):
114116
return model_dict[k]
115117

116-
raise KeyError("File not found")
118+
logger.error(
119+
"Unsupported configuration:\n"
120+
f" engine_type = {engine_type}\n"
121+
f" ocr_version = {ocr_version}\n"
122+
f" task_type = {task_type}\n"
123+
f" lang_type = {lang_type}\n"
124+
"\n"
125+
"Please refer to the official model list for supported combinations:\n"
126+
"https://rapidai.github.io/RapidOCRDocs/main/model_list/\n"
127+
"\n"
128+
"Example valid usage:\n"
129+
" from rapidocr import LangRec, OCRVersion, RapidOCR\n"
130+
" engine = RapidOCR(params={'Rec.ocr_version': OCRVersion.PPOCRV5, 'Rec.lang_type': LangRec.CH})",
131+
)
132+
raise ValueError("Invalid OCR configuration.")
117133

118134
@classmethod
119135
def get_dict_key_url(cls, file_info: FileInfo) -> str:

0 commit comments

Comments
 (0)