Skip to content

Commit 2501256

Browse files
author
Anna Grebneva
authored
Fixed opennmt evaluator in case running with not default beams (#3266)
1 parent 9825202 commit 2501256

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/opennmt_encoder_decoder_generator_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def predict(self, identifiers, input_data, encoder_callback=None):
118118
if encoder_callback:
119119
encoder_callback(raw_outputs)
120120

121-
log_probs, raw_outputs = self.generator.predict(identifiers, {'input': decoder_output.squeeze()})
121+
log_probs, raw_outputs = self.generator.predict(identifiers, {'input': decoder_output.squeeze(axis=0)})
122122
if encoder_callback:
123123
encoder_callback(raw_outputs)
124124

@@ -313,7 +313,7 @@ def _pick(self, log_probs):
313313
curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
314314
topk_ids = np.argsort(curr_scores)[..., range(self.beam_size * vocab_size - 1,
315315
self.beam_size * (vocab_size - 1) - 1, -1)]
316-
topk_scores = curr_scores[..., topk_ids.squeeze()]
316+
topk_scores = curr_scores[..., topk_ids.squeeze(axis=0)]
317317
return topk_scores, topk_ids
318318

319319
def update_finished(self):
@@ -373,7 +373,7 @@ def advance(self, log_probs):
373373
self.topk_ids = np.fmod(self.topk_ids, vocab_size) # resolve true word ids
374374

375375
self.alive_seq = np.concatenate(
376-
[np.take(self.alive_seq, self.select_indices.squeeze(), 0),
376+
[np.take(self.alive_seq, self.select_indices.squeeze(axis=-1), 0),
377377
self.topk_ids.view().reshape((_B * self.beam_size, 1))], axis=-1)
378378

379379
self.is_finished = np.equal(self.topk_ids, self.eos)

0 commit comments

Comments
 (0)