Skip to content

Commit fb3e4c0

Browse files
authored
[Infer] Add pir_model path for server infer. (PaddlePaddle#9790)
1 parent d039ad2 commit fb3e4c0

File tree

1 file changed

+7
-22
lines changed

1 file changed

+7
-22
lines changed

llm/server/server/server/engine/infer.py

Lines changed: 7 additions & 22 deletions
Diff (original file line numbers 25-30; new file line numbers 25-31):
@@ -25,6 +25,7 @@
 import paddle
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
+from paddle.base.framework import use_pir_api
 from paddlenlp.trl.llm_utils import get_rotary_position_embedding
 from paddlenlp_ops import step_paddle
 from server.data.processor import DataProcessor
@@ -467,32 +468,16 @@ def _init_predictor(self):
         predictor init
         """
         device_id = self.rank % 8
-        self.model_file = os.path.join(self.model_dir, f"model.pdmodel")
-        self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
+        if use_pir_api():
+            self.model_file = os.path.join(self.model_dir, f"model.json")
+            self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
+        else:
+            self.model_file = os.path.join(self.model_dir, f"model.pdmodel")
+            self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
         config = paddle.inference.Config(self.model_file, self.param_file)

-        config.switch_ir_optim(False)
         config.enable_use_gpu(100, device_id)

-        # distributed config
-        if self.mp_degree > 1:
-            trainer_endpoints = fleet.worker_endpoints()
-            current_endpoint = trainer_endpoints[self.rank]
-            dist_config = config.dist_config()
-            dist_config.set_ranks(self.nranks, self.rank)
-            dist_config.set_endpoints(trainer_endpoints, current_endpoint)
-            dist_config.enable_dist_model(True)
-            if self.config.distributed_config_path:
-                dist_config.set_comm_init_config(self.config.distributed_config_path)
-            else:
-                raise Exception("Please set DISTRIBUTED_CONFIG env variable.")
-                logger.warning(
-                    f"Use default distributed config, please set env DISTRIBUTED_CONFIG"
-                )
-                dist_config.set_comm_init_config(
-                    os.path.join(Dir_Path + "/config", "rank_mapping_mp{}.csv".format(self.nranks)))
-
-            config.set_dist_config(dist_config)
         self.predictor = paddle.inference.create_predictor(config)
         self.input_names = self.predictor.get_input_names()
         self.seq_lens_handle = self.predictor.get_input_handle('seq_lens_this_time')

(NOTE: indentation of the diff body is reconstructed from the control structure — the scrape dropped leading whitespace; verify against the repository before relying on exact columns.)

0 commit comments

Comments (0)