@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 """
- Copyright (c) 2020 Intel Corporation
+ Copyright (c) 2020-2021 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
@@ -20,7 +20,7 @@
 from pathlib import Path

 import numpy as np
-from openvino.inference_engine import IECore, get_version
+from openvino.runtime import Core, get_version
 from tokenizers import SentencePieceBPETokenizer

 log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.DEBUG, stream=sys.stdout)
@@ -36,8 +36,8 @@ class Translator:
         tokenizer_tgt (str): path to tgt tokenizer.
     """
     def __init__(self, model_xml, model_bin, device, tokenizer_src, tokenizer_tgt, output_name):
-        self.model = TranslationEngine(model_xml, model_bin, device, output_name)
-        self.max_tokens = self.model.get_max_tokens()
+        self.engine = TranslationEngine(model_xml, model_bin, device, output_name)
+        self.max_tokens = self.engine.get_max_tokens()
         self.tokenizer_src = Tokenizer(tokenizer_src, self.max_tokens)
         log.debug('Loaded src tokenizer, max tokens: {}'.format(self.max_tokens))
         self.tokenizer_tgt = Tokenizer(tokenizer_tgt, self.max_tokens)
@@ -56,7 +56,7 @@ def __call__(self, sentence, remove_repeats=True):
         tokens = self.tokenizer_src.encode(sentence)
         assert len(tokens) == self.max_tokens, "the input sentence is too long."
         tokens = np.array(tokens).reshape(1, -1)
-        translation = self.model(tokens)
+        translation = self.engine(tokens)
         translation = self.tokenizer_tgt.decode(translation[0], remove_repeats)
         return translation

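The call path in this hunk relies on the demo's `Tokenizer` wrapper padding each sentence to exactly `max_tokens` ids (hence the assert). A rough sketch of the underlying `tokenizers`-library round trip, with placeholder vocab/merges file names and without the padding the wrapper adds:

```python
from tokenizers import SentencePieceBPETokenizer

# Placeholder file names; the demo loads one tokenizer per language side.
tokenizer = SentencePieceBPETokenizer('vocab.json', 'merges.txt')

ids = tokenizer.encode('hello world').ids  # token ids fed to the model
text = tokenizer.decode(ids)               # back to a plain string
```
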
@@ -72,25 +72,24 @@ class TranslationEngine:
     def __init__(self, model_xml, model_bin, device, output_name):
         log.info('OpenVINO Inference Engine')
         log.info('\tbuild: {}'.format(get_version()))
-        ie = IECore()
+        core = Core()

         log.info('Reading model {}'.format(model_xml))
-        self.net = ie.read_network(
-            model=model_xml,
-            weights=model_bin
-        )
-        self.net_exec = ie.load_network(self.net, device)
+        self.model = core.read_model(model_xml, model_bin)
+        compiled_model = core.compile_model(self.model, device)
+        self.infer_request = compiled_model.create_infer_request()
         log.info('The model {} is loaded to {}'.format(model_xml, device))
-        self.output_name = output_name
-        assert self.output_name != "", "there is not output in model"
+        self.input_tensor_name = "tokens"
+        self.output_tensor_name = output_name
+        self.model.output(self.output_tensor_name)  # ensure a tensor with the name exists

     def get_max_tokens(self):
         """ Get the maximum number of tokens supported by the model.

         Returns:
             max_tokens (int): maximum number of tokens;
         """
-        return self.net.input_info["tokens"].input_data.shape[1]
+        return self.model.input(self.input_tensor_name).shape[1]

     def __call__(self, tokens):
         """ Inference method.
@@ -101,10 +100,8 @@ def __call__(self, tokens):
         Returns:
             translation (np.array): translated sentence in tokenized format.
         """
-        out = self.net_exec.infer(
-            inputs={"tokens": tokens}
-        )
-        return out[self.output_name]
+        self.infer_request.infer({self.input_tensor_name: tokens})
+        return self.infer_request.get_tensor(self.output_tensor_name).data[:]


 class Tokenizer:
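Taken together, the hunks above replace the `IECore` read/load/infer sequence with the OpenVINO 2.0 `openvino.runtime` flow. A minimal, self-contained sketch of that sequence; the model paths, the 'CPU' device, the dtype, and the output name 'pred' are placeholders (only the 'tokens' input name comes from the demo itself):

```python
import numpy as np
from openvino.runtime import Core

core = Core()
# Read the IR (placeholder paths; the weights argument may be omitted
# when the .bin file sits next to the .xml).
model = core.read_model('model.xml', 'model.bin')
compiled_model = core.compile_model(model, 'CPU')
infer_request = compiled_model.create_infer_request()

# The model takes a fixed-length 'tokens' input; build a dummy batch to fit.
max_tokens = model.input('tokens').shape[1]
tokens = np.zeros((1, max_tokens), dtype=np.int32)  # dtype is an assumption

infer_request.infer({'tokens': tokens})
result = infer_request.get_tensor('pred').data[:]  # 'pred' is a placeholder name
```
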
@@ -263,4 +260,4 @@ def sentences():

 if __name__ == "__main__":
     args = build_argparser().parse_args()
-    main(args)
+    sys.exit(main(args) or 0)
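The new last line forwards the return value of `main` as the process exit status, with `or 0` mapping a `None` return (a bare `return`, or falling off the end of the function) to success. A tiny illustration of the idiom, independent of this script:

```python
def main_ok():
    pass            # returns None

def main_fail():
    return 1        # non-zero exit status

assert (main_ok() or 0) == 0    # None maps to success
assert (main_fail() or 0) == 1  # a numeric status passes through
```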