Skip to content

Commit e9b5cfd

Browse files
authored
AC: fix custom evaluator for custom asr model (#3423)
* AC: fix custom evaluator for custom asr model * more fixes
1 parent d9a1fa1 commit e9b5cfd

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/asr_custom_encoder_decoder_joint.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ def predict(self, identifiers, input_data):
6464
outputs, raw_outputs = outputs
6565
else:
6666
raw_outputs = outputs
67-
encoder_output = np.array(outputs[self.encoder_out]).squeeze()
67+
encoder_output = outputs[self.encoder_out]
6868
self.h0 = outputs[self.h0_out]
6969
self.c0 = outputs[self.c0_out]
70-
return encoder_output, raw_outputs
70+
return encoder_output.squeeze(), raw_outputs
7171

7272
def fit_to_input(self, input_data):
7373
return {self.input: input_data, self.h0_input: self.h0, self.c0_input: self.c0}
@@ -127,7 +127,7 @@ def predict(self, identifiers, input_data, hidden=None):
127127
raw_outputs = outputs
128128
self.h0 = outputs[self.h0_out]
129129
self.c0 = outputs[self.c0_out]
130-
return np.array(outputs[self.decoder_out]).squeeze(), (self.h0, self.c0), raw_outputs
130+
return outputs[self.decoder_out].squeeze(), (self.h0, self.c0), raw_outputs
131131

132132
def fit_to_input(self, token_id, hidden):
133133
if hidden is None:
@@ -189,7 +189,7 @@ def predict(self, identifiers, input_data):
189189
else:
190190
raw_outputs = outputs
191191
joint_out = outputs[self.output]
192-
return log_softmax(np.array(joint_out).squeeze()), raw_outputs
192+
return log_softmax(joint_out), raw_outputs
193193

194194
def fit_to_input(self, encoder_out, predictor_out):
195195
return {self.input1: encoder_out, self.input2: predictor_out}
@@ -339,7 +339,7 @@ def __init__(self, network_info, launcher, suffix=None, delayed_model_loading=Fa
339339
class OVJoint(Joint, CommonOpenVINOModel):
340340
def __init__(self, network_info, launcher, suffix=None, delayed_model_loading=False):
341341
self.default_inputs = ['0', '1']
342-
self.default_outputs = ['8/sink_port']
342+
self.default_outputs = ['8/sink_port_0']
343343
super().__init__(network_info, launcher, suffix, delayed_model_loading)
344344

345345

@@ -353,25 +353,22 @@ def infer(self, input_data):
353353
results = self.inference_session.run(self.output_names, input_data)
354354
return dict(zip(self.output_names, results))
355355

356-
def select_inputs_outputs(self, network_info):
357-
pass
358-
359356

360-
class ONNXEncoder(CommonONNXModel, Encoder):
357+
class ONNXEncoder(Encoder, CommonONNXModel):
361358
def __init__(self, network_info, launcher, suffix=None, delayed_model_loading=False):
362359
self.default_inputs = ['input_0', 'input_1', 'input_2']
363360
self.default_outputs = ['output_0', 'output_1', 'output_2']
364361
super().__init__(network_info, launcher, suffix, delayed_model_loading)
365362

366363

367-
class ONNXDecoder(CommonONNXModel, Decoder):
364+
class ONNXDecoder(Decoder, CommonONNXModel):
368365
def __init__(self, network_info, launcher, suffix=None, delayed_model_loading=False):
369366
self.default_inputs = ['input_0', 'input_1', 'input_2']
370367
self.default_outputs = ['output_0', 'output_1', 'output_2']
371368
super().__init__(network_info, launcher, suffix, delayed_model_loading)
372369

373370

374-
class ONNXJoint(CommonONNXModel, Joint):
371+
class ONNXJoint(Joint, CommonONNXModel):
375372
def __init__(self, network_info, launcher, suffix=None, delayed_model_loading=False):
376373
self.default_inputs = ['0', '1']
377374
self.default_outputs = ['8']
@@ -454,7 +451,7 @@ def predict(self, identifiers, input_data, encoder_callback=None):
454451
if len(B) >= self.beam_width and yb.log_prob >= y_hat.log_prob:
455452
break
456453
B = heapq.nlargest(self.beam_width, B)
457-
return self.adapter.process([B[0].sequence], identifiers, [{}]), {}
454+
return [{}], self.adapter.process([B[0].sequence], identifiers, [{}])
458455

459456
@staticmethod
460457
def prepare_records(features):

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/base_models.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,15 @@ def automatic_model_search(self, network_info):
169169
if not model_list:
170170
model_list = list(model.glob('*.blob'))
171171
else:
172-
model_list = list(model.glob('*{}.xml'.format(self.default_model_suffix)))
173-
blob_list = list(model.glob('*{}.blob'.format(self.default_model_suffix)))
174-
if not model_list and not blob_list:
172+
model_list = list(model.glob('*{}*.xml'.format(self.default_model_suffix)))
173+
blob_list = list(model.glob('*{}*.blob'.format(self.default_model_suffix)))
174+
onnx_list = list(model.glob('*{}*.onnx'.format(self.default_model_suffix)))
175+
if not model_list and not blob_list and not onnx_list:
175176
model_list = list(model.glob('*.xml'))
176177
blob_list = list(model.glob('*.blob'))
177-
if not model_list:
178-
model_list = blob_list
178+
onnx_list = list(model.glob('*.onnx'))
179+
if not model_list:
180+
model_list = blob_list if blob_list else onnx_list
179181
if not model_list:
180182
raise ConfigError('Suitable model for {} not found'.format(self.default_model_suffix))
181183
if len(model_list) > 1:
@@ -391,7 +393,7 @@ def release(self):
391393
def automatic_model_search(self, network_info):
392394
model = Path(network_info['model'])
393395
if model.is_dir():
394-
model_list = list(model.glob('*{}.onnx'.format(self.default_model_suffix)))
396+
model_list = list(model.glob('*{}*.onnx'.format(self.default_model_suffix)))
395397
if not model_list:
396398
model_list = list(model.glob('*.onnx'))
397399
if not model_list:

0 commit comments

Comments
 (0)