Skip to content

Commit 80fd937

Browse files
authored
Merge pull request #1672 from myhloli/dev
refactor(model): integrate Ascend plugin for NPU support
2 parents 6e1fba9 + f5112e2 commit 80fd937

File tree

3 files changed

+20
-41
lines changed

3 files changed

+20
-41
lines changed

magic_pdf/model/doc_analyze_by_custom_model.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
import os
22
import time
3+
import torch
34

5+
os.environ['FLAGS_npu_jit_compile'] = '0' # 关闭paddle的jit编译
6+
os.environ['FLAGS_use_stride_kernel'] = '0'
7+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback
8+
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
49
# 关闭paddle的信号处理
510
import paddle
6-
import torch
11+
paddle.disable_signal_handler()
12+
713
from loguru import logger
814

915
from magic_pdf.model.batch_analyze import BatchAnalyze
1016
from magic_pdf.model.sub_modules.model_utils import get_vram
1117

12-
paddle.disable_signal_handler()
13-
14-
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
15-
1618
try:
1719
import torchtext
18-
1920
if torchtext.__version__ >= '0.18.0':
2021
torchtext.disable_torchtext_deprecation_warning()
2122
except ImportError:
@@ -32,20 +33,6 @@
3233
from magic_pdf.operators.models import InferenceResult
3334

3435

35-
def dict_compare(d1, d2):
36-
return d1.items() == d2.items()
37-
38-
39-
def remove_duplicates_dicts(lst):
40-
unique_dicts = []
41-
for dict_item in lst:
42-
if not any(
43-
dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
44-
):
45-
unique_dicts.append(dict_item)
46-
return unique_dicts
47-
48-
4936
class ModelSingleton:
5037
_instance = None
5138
_models = {}

magic_pdf/model/pdf_extract_kit.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,6 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
8989
# 初始化解析方案
9090
self.device = kwargs.get('device', 'cpu')
9191

92-
if str(self.device).startswith("npu"):
93-
import torch_npu
94-
os.environ['FLAGS_npu_jit_compile'] = '0'
95-
os.environ['FLAGS_use_stride_kernel'] = '0'
96-
elif str(self.device).startswith("mps"):
97-
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
98-
9992
logger.info('using device: {}'.format(self.device))
10093
models_dir = kwargs.get(
10194
'models_dir', os.path.join(root_dir, 'resources', 'models')

magic_pdf/model/sub_modules/model_init.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,22 @@
44
from magic_pdf.config.constants import MODEL_NAME
55
from magic_pdf.model.model_list import AtomicModel
66
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
7-
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
8-
DocLayoutYOLOModel
9-
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
10-
Layoutlmv3_Predictor
7+
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
8+
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
119
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
1210
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
13-
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
14-
ModifiedPaddleOCR
15-
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
16-
RapidTableModel
17-
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
18-
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
19-
StructTableModel
20-
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
21-
TableMasterPaddleModel
2211

12+
try:
13+
from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
14+
from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
15+
logger.info('Using Ascend Plugin')
16+
except ImportError:
17+
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
18+
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
19+
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
20+
21+
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
22+
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
2323

2424
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
2525
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
@@ -76,7 +76,6 @@ def ocr_model_init(show_log: bool = False,
7676
use_dilation=True,
7777
det_db_unclip_ratio=1.8,
7878
):
79-
8079
if lang is not None and lang != '':
8180
model = ModifiedPaddleOCR(
8281
show_log=show_log,

0 commit comments

Comments
 (0)