|
25 | 25 | import paddle |
26 | 26 | import paddle.distributed as dist |
27 | 27 | import paddle.distributed.fleet as fleet |
| 28 | +from paddle.base.framework import use_pir_api |
28 | 29 | from paddlenlp.trl.llm_utils import get_rotary_position_embedding |
29 | 30 | from paddlenlp_ops import step_paddle |
30 | 31 | from server.data.processor import DataProcessor |
@@ -467,32 +468,16 @@ def _init_predictor(self): |
467 | 468 | predictor init |
468 | 469 | """ |
469 | 470 | device_id = self.rank % 8 |
470 | | - self.model_file = os.path.join(self.model_dir, f"model.pdmodel") |
471 | | - self.param_file = os.path.join(self.model_dir, f"model.pdiparams") |
| 471 | + if use_pir_api(): |
| 472 | + self.model_file = os.path.join(self.model_dir, f"model.json") |
| 473 | + self.param_file = os.path.join(self.model_dir, f"model.pdiparams") |
| 474 | + else: |
| 475 | + self.model_file = os.path.join(self.model_dir, f"model.pdmodel") |
| 476 | + self.param_file = os.path.join(self.model_dir, f"model.pdiparams") |
472 | 477 | config = paddle.inference.Config(self.model_file, self.param_file) |
473 | 478 |
|
474 | | - config.switch_ir_optim(False) |
475 | 479 | config.enable_use_gpu(100, device_id) |
476 | 480 |
|
477 | | - # distributed config |
478 | | - if self.mp_degree > 1: |
479 | | - trainer_endpoints = fleet.worker_endpoints() |
480 | | - current_endpoint = trainer_endpoints[self.rank] |
481 | | - dist_config = config.dist_config() |
482 | | - dist_config.set_ranks(self.nranks, self.rank) |
483 | | - dist_config.set_endpoints(trainer_endpoints, current_endpoint) |
484 | | - dist_config.enable_dist_model(True) |
485 | | - if self.config.distributed_config_path: |
486 | | - dist_config.set_comm_init_config(self.config.distributed_config_path) |
487 | | - else: |
488 | | - raise Exception("Please set DISTRIBUTED_CONFIG env variable.") |
489 | | - logger.warning( |
490 | | - f"Use default distributed config, please set env DISTRIBUTED_CONFIG" |
491 | | - ) |
492 | | - dist_config.set_comm_init_config( |
493 | | - os.path.join(Dir_Path + "/config", "rank_mapping_mp{}.csv".format(self.nranks))) |
494 | | - |
495 | | - config.set_dist_config(dist_config) |
496 | 481 | self.predictor = paddle.inference.create_predictor(config) |
497 | 482 | self.input_names = self.predictor.get_input_names() |
498 | 483 | self.seq_lens_handle = self.predictor.get_input_handle('seq_lens_this_time') |
|
0 commit comments