Skip to content

Commit f67162b

Browse files
authored
AC: open nmt evaluator for api 2.0 (#2950)
1 parent 87ffa03 commit f67162b

File tree

1 file changed

+60
-3
lines changed

1 file changed

+60
-3
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/opennmt_encoder_decoder_generator_evaluator.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
import numpy as np
2020

2121
from .base_custom_evaluator import BaseCustomEvaluator
22-
from .base_models import BaseCascadeModel, BaseDLSDKModel, create_model, BaseONNXModel
22+
from .base_models import BaseCascadeModel, BaseDLSDKModel, create_model, BaseONNXModel, BaseOpenVINOModel
2323
from ...adapters import create_adapter
2424
from ...config import ConfigError
25-
from ...utils import contains_all, contains_any, extract_image_representations
25+
from ...utils import contains_all, contains_any, extract_image_representations, parse_partial_shape
2626

2727

2828
class OpenNMTEvaluator(BaseCustomEvaluator):
@@ -78,15 +78,18 @@ def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_l
7878
self._encoder_mapping = {
7979
'dlsdk': EncoderDLSDKModel,
8080
'onnx_runtime': EncoderONNXModel,
81+
'openvino': EncoderOVModel
8182
}
8283
self._decoder_mapping = {
8384
'dlsdk': DecoderDLSDKModel,
8485
'onnx_runtime': DecoderONNXModel,
86+
'openvino': DecoderOVModel
8587
}
8688

8789
self._generator_mapping = {
8890
'dlsdk': GeneratorDLSDKModel,
89-
'onnx_runtime': GeneratorONNXModel
91+
'onnx_runtime': GeneratorONNXModel,
92+
'openvino': GeneratorOVModel
9093
}
9194

9295
self.encoder = create_model(network_info['encoder'], launcher, self._encoder_mapping, 'encoder',
@@ -215,6 +218,38 @@ def propagate_output(self, data):
215218
pass
216219

217220

221+
class CommonOVModel(BaseOpenVINOModel):
222+
default_model_suffix = 'encoder'
223+
input_layers = []
224+
output_layers = []
225+
return_layers = []
226+
227+
def predict(self, identifiers, input_data, callback=None):
228+
input_data = self.fit_to_input(input_data)
229+
results = self.infer(input_data)
230+
self.propagate_output(results)
231+
names = self.return_layers if len(self.return_layers) > 0 else self.output_layers
232+
return tuple(results[name] for name in names) + (results,)
233+
234+
def fit_to_input(self, input_data):
235+
if isinstance(input_data, dict):
236+
fitted = {}
237+
for input_blob in self.inputs:
238+
fitted.update(self.fit_one_input(input_blob, input_data[input_blob]))
239+
else:
240+
fitted = self.fit_one_input(self.input_blob, input_data)
241+
return fitted
242+
243+
def fit_one_input(self, input_blob, input_data):
244+
if input_blob in self.dynamic_inputs or parse_partial_shape(self.inputs[input_blob]) != np.shape(input_data):
245+
self._reshape_input({input_blob: np.shape(input_data)})
246+
247+
return {input_blob: np.array(input_data)}
248+
249+
def propagate_output(self, data):
250+
pass
251+
252+
218253
class BeamSearch:
219254
def __init__(self, config):
220255
self.batch_size = config.get('batch', 1)
@@ -399,6 +434,28 @@ class GeneratorDLSDKModel(CommonDLSDKModel):
399434
output_layers = ['output']
400435

401436

437+
class EncoderOVModel(CommonOVModel):
438+
default_model_suffix = 'encoder'
439+
input_layers = ['src', 'src_len']
440+
output_layers = ['state.0/sink_port_0', 'state.1/sink_port_0', 'memory/sink_port_0']
441+
return_layers = ['state.0/sink_port_0', 'state.1/sink_port_0', 'memory/sink_port_0']
442+
443+
444+
class DecoderOVModel(CommonOpenNMTDecoder, CommonOVModel):
445+
default_model_suffix = 'decoder'
446+
input_layers = ['c_0', 'h_0', 'input', 'input_feed.1', 'mem_len', 'memory']
447+
output_layers = ['attn', 'c_1', 'h_1', 'input_feed', 'output']
448+
return_layers = ['output', 'attn']
449+
state_inputs = ['h_0', 'c_0', 'memory', 'mem_len', 'input_feed.1']
450+
state_outputs = ['h_1', 'c_1', '', '', 'input_feed']
451+
452+
453+
class GeneratorOVModel(CommonOVModel):
454+
default_model_suffix = 'generator'
455+
input_layers = ['input']
456+
output_layers = ['output/sink_port_0']
457+
458+
402459
class CommonONNXModel(BaseONNXModel):
403460
default_model_suffix = 'encoder'
404461
input_layers = []

0 commit comments

Comments
 (0)