2 files changed: +10 -10 lines changed

First changed file:
@@ -16,6 +16,7 @@
 import time
 import argparse
 import json
+from decimal import Decimal
 import numpy as np
 from paddlenlp.utils.log import logger
 
@@ -35,7 +36,11 @@ def do_convert():
     if len(args.splits) != 0 and len(args.splits) != 3:
         raise ValueError("Only []/ len(splits)==3 accepted for splits.")
 
-    if args.splits and sum(args.splits) != 1:
+    def _check_sum(splits):
+        return Decimal(str(splits[0])) + Decimal(str(splits[1])) + Decimal(
+            str(splits[2])) == Decimal("1")
+
+    if len(args.splits) == 3 and not _check_sum(args.splits):
         raise ValueError(
             "Please set correct splits, sum of elements in splits should be equal to 1."
         )
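
The replaced check compared a float sum directly against 1, which fails for ratios that cannot be represented exactly in binary floating point; the new _check_sum helper routes each split through its string form into decimal.Decimal and compares exactly. A minimal standalone sketch of that idea, with hypothetical split ratios (not taken from the patch):

from decimal import Decimal

splits = [0.7, 0.2, 0.1]   # hypothetical train/dev/test ratios

# Plain float addition accumulates rounding error, so the old check
# `sum(args.splits) != 1` rejects ratios a user would consider valid.
print(sum(splits) == 1)    # False: the float sum lands just below 1.0

# Converting each value via its string form, as _check_sum does,
# keeps the comparison exact.
print(Decimal(str(splits[0])) + Decimal(str(splits[1]))
      + Decimal(str(splits[2])) == Decimal("1"))   # True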
Second changed file:

@@ -208,15 +208,10 @@ def _prepare_onnx_mode(self):
         sess_options.inter_op_num_threads = self._num_threads
         self.predictor = ort.InferenceSession(
             fp16_model_file, sess_options=sess_options, providers=providers)
-        try:
-            assert 'CUDAExecutionProvider' in self.predictor.get_providers()
-        except AssertionError:
-            raise AssertionError(
-                f"The environment for GPU inference is not set properly. "
-                "A possible cause is that you had installed both onnxruntime and onnxruntime-gpu. "
-                "Please run the following commands to reinstall: \n "
-                "1) pip uninstall -y onnxruntime onnxruntime-gpu \n 2) pip install onnxruntime-gpu"
-            )
+        assert 'CUDAExecutionProvider' in self.predictor.get_providers(), f"The environment for GPU inference is not set properly. " \
+            "A possible cause is that you had installed both onnxruntime and onnxruntime-gpu. " \
+            "Please run the following commands to reinstall: \n " \
+            "1) pip uninstall -y onnxruntime onnxruntime-gpu \n 2) pip install onnxruntime-gpu"
 
     def _get_inference_model(self):
         """