Skip to content

Commit 7852629

Browse files
authored
disable ir optim for nptag (#3107)
1 parent 08636b3 commit 7852629

File tree

2 files changed

+4
-21
lines changed

2 files changed

+4
-21
lines changed

examples/text_to_knowledge/nptag/deploy/python/predict.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,12 @@ def __init__(self, model_dir, device):
4949
if not os.path.exists(params_file):
5050
raise ValueError("not find params file path {}".format(params_file))
5151
config = paddle.inference.Config(model_file, params_file)
52+
# Disable IR optimization for NPTag
53+
config.switch_ir_optim(False)
5254

5355
if device == "gpu":
5456
# set GPU configs accordingly
5557
config.enable_use_gpu(100, 0)
56-
config.delete_pass("embedding_eltwise_layernorm_fuse_pass")
5758
elif device == "cpu":
5859
# set CPU configs accordingly,
5960
# such as enable_mkldnn, set_cpu_math_library_num_threads

paddlenlp/taskflow/knowledge_mining.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -580,26 +580,8 @@ def __init__(self,
580580
self._construct_dict_map()
581581

582582
self._get_inference_model()
583-
if paddle.get_device().startswith("gpu"):
584-
inference_model_path = os.path.join(self._task_path, "static",
585-
"inference")
586-
model_file = inference_model_path + ".pdmodel"
587-
params_file = inference_model_path + ".pdiparams"
588-
self._config = paddle.inference.Config(model_file, params_file)
589-
self._config.enable_use_gpu(100, 0)
590-
self._config.switch_use_feed_fetch_ops(False)
591-
self._config.disable_glog_info()
592-
# TODO(linjieccc): enable embedding_eltwise_layernorm_fuse_pass after fixed
593-
self._config.delete_pass("embedding_eltwise_layernorm_fuse_pass")
594-
self.predictor = paddle.inference.create_predictor(self._config)
595-
self.input_handles = [
596-
self.predictor.get_input_handle(name)
597-
for name in self.predictor.get_input_names()
598-
]
599-
self.output_handle = [
600-
self.predictor.get_output_handle(name)
601-
for name in self.predictor.get_output_names()
602-
]
583+
# Disable IR optimization for NPTag
584+
self._config.switch_ir_optim(False)
603585

604586
@property
605587
def summary_num(self):

0 commit comments

Comments
 (0)