Skip to content

Commit f3059a5

Browse files
PC91 and Thai Chau Truong authored
Add script for n_best parameter in topp/topk (#2509)
* Add script for n_best parameter in topp/topk

Co-authored-by: Thai Chau Truong <tctruong@dom_softissimo.lan>
1 parent c5c84af commit f3059a5

File tree

3 files changed: 17 additions and 2 deletions.

onmt/tests/test_greedy_search.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
4646
2,
4747
3,
4848
1,
49+
1,
4950
batch_sz,
5051
GlobalScorerStub(),
5152
min_length,
@@ -100,6 +101,7 @@ def test_returns_correct_scores_deterministic(self):
100101
2,
101102
3,
102103
1,
104+
1,
103105
batch_sz,
104106
GlobalScorerStub(),
105107
0,
@@ -186,6 +188,7 @@ def test_returns_correct_scores_non_deterministic(self):
186188
2,
187189
3,
188190
1,
191+
1,
189192
batch_sz,
190193
GlobalScorerStub(),
191194
0,
@@ -297,6 +300,7 @@ def test_returns_correct_scores_non_deterministic_beams(self):
297300
2,
298301
3,
299302
1,
303+
1,
300304
batch_sz,
301305
GlobalScorerStub(),
302306
0,
@@ -374,7 +378,7 @@ def test_returns_correct_scores_non_deterministic_beams(self):
374378

375379
samp.update_finished()
376380
self.assertEqual(
377-
[score for score, _, _ in samp.hypotheses[batch_sz - 1][-1:]],
381+
[score for score, _, _ in samp.hypotheses[batch_sz - 1][:1]],
378382
[valid_score_dist_2[0] / temp],
379383
)
380384

@@ -419,6 +423,7 @@ def test_returns_correct_scores_non_deterministic_topp(self):
419423
2,
420424
3,
421425
1,
426+
1,
422427
batch_sz,
423428
GlobalScorerStub(),
424429
0,

onmt/translate/greedy_search.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class GreedySearch(DecodeStrategy):
9898
eos (int): See base.
9999
unk (int): See base.
100100
start (int): See base.
101+
n_best (int): Don't stop until at least this many beams have
102+
reached EOS.
101103
batch_size (int): See base.
102104
global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
103105
min_length (int): See base.
@@ -123,6 +125,7 @@ def __init__(
123125
eos,
124126
unk,
125127
start,
128+
n_best,
126129
batch_size,
127130
global_scorer,
128131
min_length,
@@ -157,6 +160,7 @@ def __init__(
157160
self.keep_topp = keep_topp
158161
self.topk_scores = None
159162
self.beam_size = beam_size
163+
self.n_best = n_best
160164

161165
def initialize(
162166
self, enc_out, src_len, src_map=None, device=None, target_prefix=None
@@ -265,10 +269,14 @@ def update_finished(self):
265269
else []
266270
)
267271
self.hypotheses[b_orig].append((score, pred, attention))
272+
if len(self.hypotheses[b_orig]) >= 2:
273+
self.hypotheses[b_orig] = sorted(
274+
self.hypotheses[b_orig], key=lambda x: x[0], reverse=True
275+
)
268276
self.done = self.is_finished.all()
269277
if self.done:
270278
for b in range(self.batch_size):
271-
best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True)
279+
best_hyp = self.hypotheses[b][: self.n_best]
272280
for score, pred, attn in best_hyp:
273281
self.scores[b].append(score)
274282
self.predictions[b].append(pred)

onmt/translate/translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,7 @@ def translate_batch(self, batch, attn_debug):
810810
eos=self._tgt_eos_idx,
811811
unk=self._tgt_unk_idx,
812812
start=self._tgt_start_with,
813+
n_best=self.n_best,
813814
batch_size=len(batch["srclen"]),
814815
global_scorer=self.global_scorer,
815816
min_length=self.min_length,
@@ -1009,6 +1010,7 @@ def translate_batch(self, batch, attn_debug):
10091010
eos=self._tgt_eos_idx,
10101011
unk=self._tgt_unk_idx,
10111012
start=self._tgt_start_with,
1013+
n_best=self.n_best,
10121014
batch_size=len(batch["srclen"]),
10131015
global_scorer=self.global_scorer,
10141016
min_length=self.min_length,

0 commit comments

Comments
 (0)