Skip to content

Commit 082dc52

Browse files
authored
[Improvement] fix ops improting in utils (PaddlePaddle#7865)
* fix ops improting in utils * update the reset_stop_value importing * update the error msg of import paddlenlp_ops * rename error -> warning
1 parent 283c535 commit 082dc52

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

llm/predictor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import paddle.distributed.fleet.base.topology as tp
2828
import paddle.incubate.multiprocessing as mp
2929
from paddle.distributed import fleet
30-
from paddlenlp_ops import reset_stop_value
3130
from utils import (
3231
dybatch_preprocess,
3332
get_alibi_slopes,
@@ -57,6 +56,16 @@
5756
from paddlenlp.utils.import_utils import import_module, is_paddlenlp_ops_available
5857
from paddlenlp.utils.log import logger
5958

59+
try:
60+
from paddlenlp_ops import reset_stop_value
61+
except (ImportError, ModuleNotFoundError):
62+
logger.warning(
63+
"if you run predictor.py with --inference_model argument, please ensure you install "
64+
"the paddlenlp_ops by following the instructions "
65+
"provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
66+
)
67+
68+
6069
# Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output
6170
MAX_BSZ = 512
6271

llm/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import paddle.incubate.multiprocessing as mp
2626
from paddle.distributed import fleet
2727
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
28-
from paddlenlp_ops import get_output
2928
from sklearn.metrics import accuracy_score
3029

3130
from paddlenlp.datasets import InTokensIterableDataset
@@ -704,6 +703,9 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q
704703

705704
logger.info("Start read result message")
706705
logger.info(f"Current path is {os.getcwd()}")
706+
707+
from paddlenlp_ops import get_output
708+
707709
while True:
708710
get_output(output_tensor, 0, True)
709711
if output_tensor[0, 0] == -2: # read none

0 commit comments

Comments
 (0)