
Commit 00daaec

add NoRepeatNGramLogitsProcessor (#3977)
* add NoRepeatNGramLogitsProcessor
* Update generation_utils.py: solve invalid syntax
* Update generation_utils.py
* Update generation_utils.py
* Update generation_utils.py: add no_repeat_ngram_size=None
* Update generation_utils.py: fix code style
* fix code style
1 parent 4d736b5 commit 00daaec
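
For orientation, a minimal usage sketch of the new argument follows. The model, tokenizer, prompt, and surrounding call are illustrative assumptions, not part of this commit; only `no_repeat_ngram_size` is what the commit adds.

# Illustrative usage sketch (model name, tokenizer, and prompt are assumptions,
# not part of the commit; the new argument is no_repeat_ngram_size).
import paddle
from paddlenlp.transformers import GPTLMHeadModel, GPTTokenizer

tokenizer = GPTTokenizer.from_pretrained("gpt2-en")
model = GPTLMHeadModel.from_pretrained("gpt2-en")
model.eval()

input_ids = paddle.to_tensor([tokenizer("The weather today is")["input_ids"]])
ids, scores = model.generate(
    input_ids,
    max_length=50,
    decode_strategy="beam_search",
    num_beams=4,
    no_repeat_ngram_size=2,  # ban any bigram from appearing twice in the output
)
print(tokenizer.convert_ids_to_string(ids[0].tolist()))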


paddlenlp/transformers/generation_utils.py

Lines changed: 69 additions & 2 deletions
@@ -13,15 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
 import inspect
 from abc import ABC
+from typing import List
 
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.common_ops_import import convert_dtype
 from paddle.fluid.layers.utils import map_structure
+
 from paddlenlp.utils.log import logger
 
 __all__ = ["GenerationMixin"]
@@ -306,6 +307,7 @@ def get_logits_processor(
         num_beam_groups=1,
         diversity_rate=0.0,
         repetition_penalty=None,
+        no_repeat_ngram_size=None,
         logits_processors=None,
     ):
         processors = LogitsProcessorList()
@@ -320,6 +322,8 @@ def get_logits_processor(
             )
         if repetition_penalty is not None and repetition_penalty != 1.0:
             processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
+        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
+            processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
         if forced_bos_token_id is not None:
             processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
         if forced_eos_token_id is not None:
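
With that change, the n-gram ban participates in the per-step logits rewriting like any other processor. Below is a simplified sketch of how such a processor list is typically applied at each decoding step; the class name is illustrative and this is not PaddleNLP's actual `LogitsProcessorList` implementation.

# Simplified sketch only; not the actual PaddleNLP LogitsProcessorList code.
class SketchLogitsProcessorList(list):
    def __call__(self, input_ids, logits):
        # Each processor sees the tokens generated so far and may rewrite the
        # next-token logits, e.g. setting banned continuations to -inf.
        for processor in self:
            logits = processor(input_ids, logits)
        return logits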
@@ -503,7 +507,7 @@ def _build_faster(self, kwargs):
         if kwargs["num_beam_groups"] != 1:
             # not support for group_beam_search yet in the faster version
             raise AttributeError("'num_beam_groups != 1' is not supported yet in the faster version")
-        if paddle.get_default_dtype() == "float16" and kwargs["use_fp16_decoding"] == False:
+        if paddle.get_default_dtype() == "float16" and kwargs["use_fp16_decoding"] is False:
             logger.info(
                 "Since the default dtype is float16, float16 would be used " "though 'use_fp16_decoding=False'."
             )
@@ -531,6 +535,7 @@ def generate(
         decoder_start_token_id=None,
         forced_bos_token_id=None,
         forced_eos_token_id=None,
+        no_repeat_ngram_size=None,
         num_return_sequences=1,
         diversity_rate=0.0,
         use_cache=True,
@@ -729,6 +734,9 @@ def generate(
             if decoder_start_token_id is not None
             else getattr(self, "decoder_start_token_id", None)
         )
+        no_repeat_ngram_size = (
+            no_repeat_ngram_size if no_repeat_ngram_size is not None else getattr(self, "no_repeat_ngram_size", None)
+        )
 
         if getattr(self, "_faster_entry", None) is not False and use_faster:
             args = locals()
@@ -804,6 +812,7 @@ def generate(
             num_beam_groups=num_beam_groups,
             diversity_rate=diversity_rate,
             repetition_penalty=repetition_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
             logits_processors=model_kwargs["logits_processors"]
             if "logits_processors" in model_kwargs
             and isinstance(model_kwargs["logits_processors"], LogitsProcessorList)
@@ -1337,6 +1346,64 @@ def __call__(self, input_ids, logits):
         return outputs
 
 
+def _get_ngrams(ngram_size, prev_input_ids, num_hypos):
+    generated_ngrams = [{} for _ in range(num_hypos)]
+    for idx in range(num_hypos):
+        gen_tokens = prev_input_ids[idx].tolist()
+        generated_ngram = generated_ngrams[idx]
+        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
+            prev_ngram_tuple = tuple(ngram[:-1])
+            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
+    return generated_ngrams
+
+
+def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
+    # Before decoding the next token, prevent decoding of ngrams that have already appeared
+    start_idx = cur_len + 1 - ngram_size
+    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
+    return banned_ngrams.get(ngram_idx, [])
+
+
+def _calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len):
+    """Copied from fairseq for no_repeat_ngram in beam_search"""
+    if cur_len + 1 < ngram_size:
+        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+        return [[] for _ in range(num_hypos)]
+
+    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
+
+    banned_tokens = [
+        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
+        for hypo_idx in range(num_hypos)
+    ]
+    return banned_tokens
+
+
+class NoRepeatNGramLogitsProcessor(LogitsProcessor):
+    r"""
+    [`LogitsProcessor`] that enforces no repetition of n-grams. See
+    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
+    Args:
+        ngram_size (`int`):
+            All ngrams of size `ngram_size` can only occur once.
+    """
+
+    def __init__(self, ngram_size):
+        if not isinstance(ngram_size, int) or ngram_size <= 0:
+            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+        self.ngram_size = ngram_size
+
+    def __call__(self, input_ids, scores):
+        num_batch_hypotheses = scores.shape[0]
+        cur_len = input_ids.shape[-1]
+        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
+
+        for i, banned_tokens in enumerate(banned_batch_tokens):
+            scores[i, banned_tokens] = -float("inf")
+
+        return scores
+
+
 class HammingDiversityLogitsProcessor(LogitsProcessor):
     """
     This `LogitsProcessor` enforces diverse beam search. Note that this logits
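
To see what the added helpers compute, here is a self-contained toy sketch of the same banning logic for a single hypothesis, using plain Python lists instead of tensors. The function name and the sample tokens are illustrative only, not part of the commit.

# Mirrors the intent of _get_ngrams / _calc_banned_ngram_tokens for one
# hypothesis; names and data are illustrative.
def banned_next_tokens(ngram_size, prev_tokens):
    if len(prev_tokens) + 1 < ngram_size:
        # Not enough tokens yet to complete a repeated n-gram.
        return []
    # Map each (ngram_size - 1)-token prefix to the tokens that followed it.
    seen = {}
    for ngram in zip(*[prev_tokens[i:] for i in range(ngram_size)]):
        seen.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    # The current prefix is the last (ngram_size - 1) generated tokens; any
    # token that would complete an already-seen n-gram is banned next step.
    cur_prefix = tuple(prev_tokens[len(prev_tokens) + 1 - ngram_size:])
    return seen.get(cur_prefix, [])


print(banned_next_tokens(2, [5, 7, 3, 5]))  # -> [7], since the bigram (5, 7) already occurred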
