12
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
- import os
16
-
17
15
import argparse
18
- import numpy as np
19
16
from functools import partial
20
17
21
18
import paddle
22
- from paddle import inference
23
- from paddlenlp .data import Stack , Tuple , Pad , Vocab
24
- from paddlenlp .transformers import ErnieTokenizer
25
-
26
19
from utils import convert_example , parse_decode
27
20
28
- # yapf: disable
21
+ from paddlenlp .data import Pad , Stack , Tuple , Vocab
22
+ from paddlenlp .transformers import ErnieTokenizer
23
+
29
24
parser = argparse .ArgumentParser (__doc__ )
30
- parser .add_argument ("--model_file" , type = str , required = True , default = './static_graph_params.pdmodel' , help = "The path to model info in static graph." )
31
- parser .add_argument ("--params_file" , type = str , required = True , default = './static_graph_params.pdiparams' , help = "The path to parameters in static graph." )
25
+ parser .add_argument (
26
+ "--model_file" ,
27
+ type = str ,
28
+ required = True ,
29
+ default = "./static_graph_params.pdmodel" ,
30
+ help = "The path to model info in static graph." ,
31
+ )
32
+ parser .add_argument (
33
+ "--params_file" ,
34
+ type = str ,
35
+ required = True ,
36
+ default = "./static_graph_params.pdiparams" ,
37
+ help = "The path to parameters in static graph." ,
38
+ )
32
39
parser .add_argument ("--batch_size" , type = int , default = 2 , help = "The number of sequences contained in a mini-batch." )
33
40
parser .add_argument ("--max_seq_len" , type = int , default = 64 , help = "Number of words of the longest seqence." )
34
- parser .add_argument ("--device" , default = "gpu" , type = str , choices = ["cpu" , "gpu" ] ,help = "The device to select to train the model, is must be cpu/gpu." )
41
+ parser .add_argument (
42
+ "--device" ,
43
+ default = "gpu" ,
44
+ type = str ,
45
+ choices = ["cpu" , "gpu" ],
46
+ help = "The device to select to train the model, is must be cpu/gpu." ,
47
+ )
35
48
parser .add_argument ("--pinyin_vocab_file_path" , type = str , default = "pinyin_vocab.txt" , help = "pinyin vocab file path" )
36
49
37
50
args = parser .parse_args ()
38
- # yapf: enable
39
51
40
52
41
53
class Predictor (object ):
@@ -51,6 +63,7 @@ def __init__(self, model_file, params_file, device, max_seq_length, tokenizer, p
51
63
# such as enable_mkldnn, set_cpu_math_library_num_threads
52
64
config .disable_gpu ()
53
65
config .switch_use_feed_fetch_ops (False )
66
+ config .delete_pass ("fused_multi_transformer_encoder_pass" )
54
67
self .predictor = paddle .inference .create_predictor (config )
55
68
56
69
self .input_handles = [self .predictor .get_input_handle (name ) for name in self .predictor .get_input_names ()]
0 commit comments