
Commit c2796b9

ov 2.0: python machine_translation_demo (#3031)
1 parent 5df25a5 commit c2796b9


demos/machine_translation_demo/python/machine_translation_demo.py

Lines changed: 16 additions & 19 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 """
-Copyright (c) 2020 Intel Corporation
+Copyright (c) 2020-2021 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
@@ -20,7 +20,7 @@
 from pathlib import Path
 
 import numpy as np
-from openvino.inference_engine import IECore, get_version
+from openvino.runtime import Core, get_version
 from tokenizers import SentencePieceBPETokenizer
 
 log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.DEBUG, stream=sys.stdout)
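In the OpenVINO 2.0 Python API, the `openvino.inference_engine` module gives way to `openvino.runtime`, with `IECore` renamed to `Core` while `get_version` keeps its name. A minimal sketch of the new entry point (purely illustrative, not part of the demo):

```python
from openvino.runtime import Core, get_version

core = Core()                   # replaces IECore() from the pre-2.0 API
print(get_version())            # runtime build string, same helper as before
print(core.available_devices)   # e.g. ['CPU', 'GPU']
```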
@@ -36,8 +36,8 @@ class Translator:
         tokenizer_tgt (str): path to tgt tokenizer.
     """
     def __init__(self, model_xml, model_bin, device, tokenizer_src, tokenizer_tgt, output_name):
-        self.model = TranslationEngine(model_xml, model_bin, device, output_name)
-        self.max_tokens = self.model.get_max_tokens()
+        self.engine = TranslationEngine(model_xml, model_bin, device, output_name)
+        self.max_tokens = self.engine.get_max_tokens()
         self.tokenizer_src = Tokenizer(tokenizer_src, self.max_tokens)
         log.debug('Loaded src tokenizer, max tokens: {}'.format(self.max_tokens))
         self.tokenizer_tgt = Tokenizer(tokenizer_tgt, self.max_tokens)
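Renaming `Translator.model` to `Translator.engine` is more than cosmetic: after the migration, `TranslationEngine` itself exposes a `self.model` attribute holding the `openvino.runtime` model object (see the `__init__` hunk below), so the wrapper keeps a distinct name for the engine it delegates to.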
@@ -56,7 +56,7 @@ def __call__(self, sentence, remove_repeats=True):
         tokens = self.tokenizer_src.encode(sentence)
         assert len(tokens) == self.max_tokens, "the input sentence is too long."
         tokens = np.array(tokens).reshape(1, -1)
-        translation = self.model(tokens)
+        translation = self.engine(tokens)
         translation = self.tokenizer_tgt.decode(translation[0], remove_repeats)
         return translation
 
@@ -72,25 +72,24 @@ class TranslationEngine:
     def __init__(self, model_xml, model_bin, device, output_name):
         log.info('OpenVINO Inference Engine')
         log.info('\tbuild: {}'.format(get_version()))
-        ie = IECore()
+        core = Core()
 
         log.info('Reading model {}'.format(model_xml))
-        self.net = ie.read_network(
-            model=model_xml,
-            weights=model_bin
-        )
-        self.net_exec = ie.load_network(self.net, device)
+        self.model = core.read_model(model_xml, model_bin)
+        compiled_model = core.compile_model(self.model, device)
+        self.infer_request = compiled_model.create_infer_request()
         log.info('The model {} is loaded to {}'.format(model_xml, device))
-        self.output_name = output_name
-        assert self.output_name != "", "there is not output in model"
+        self.input_tensor_name = "tokens"
+        self.output_tensor_name = output_name
+        self.model.output(self.output_tensor_name)  # ensure a tensor with the name exists
 
     def get_max_tokens(self):
         """ Get maximum number of tokens that supported by model.
 
         Returns:
             max_tokens (int): maximum number of tokens;
         """
-        return self.net.input_info["tokens"].input_data.shape[1]
+        return self.model.input(self.input_tensor_name).shape[1]
 
     def __call__(self, tokens):
         """ Inference method.
@@ -101,10 +100,8 @@ def __call__(self, tokens):
         Returns:
             translation (np.array): translated sentence in tokenized format.
         """
-        out = self.net_exec.infer(
-            inputs={"tokens": tokens}
-        )
-        return out[self.output_name]
+        self.infer_request.infer({self.input_tensor_name: tokens})
+        return self.infer_request.get_tensor(self.output_tensor_name).data[:]
 
 
 class Tokenizer:
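Where the old `ExecutableNetwork.infer(inputs=...)` returned a dict keyed by output name, the `InferRequest` runs synchronously on a `{tensor_name: array}` mapping and exposes outputs as tensors fetched by name. A minimal sketch continuing the placeholders above:

```python
import numpy as np

tokens = np.zeros((1, max_tokens), dtype=np.int32)    # dummy batch of token ids
infer_request.infer({'tokens': tokens})               # synchronous inference
result = infer_request.get_tensor('pred').data[:]     # numpy data of the named output
```

Here `'pred'` stands in for whatever output tensor name is passed to the demo.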
@@ -263,4 +260,4 @@ def sentences():
 
 if __name__ == "__main__":
     args = build_argparser().parse_args()
-    main(args)
+    sys.exit(main(args) or 0)
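The new exit line turns `main`'s return value into the process exit code: since `main` returns `None` on success, `None or 0` evaluates to `0`, while any truthy error code would pass through to `sys.exit`. A tiny sketch of the idiom:

```python
import sys

def main(args):
    return None     # or a non-zero int on failure

if __name__ == '__main__':
    sys.exit(main(None) or 0)   # None or 0 -> 0 (success)
```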
