# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import List
import inspect
from abc import ABC
+from typing import List

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.common_ops_import import convert_dtype
from paddle.fluid.layers.utils import map_structure
+
from paddlenlp.utils.log import logger

__all__ = ["GenerationMixin"]
@@ -306,6 +307,7 @@ def get_logits_processor(
        num_beam_groups=1,
        diversity_rate=0.0,
        repetition_penalty=None,
+        no_repeat_ngram_size=None,
        logits_processors=None,
    ):
        processors = LogitsProcessorList()
@@ -320,6 +322,8 @@ def get_logits_processor(
            )
        if repetition_penalty is not None and repetition_penalty != 1.0:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
+        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
+            processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
        if forced_bos_token_id is not None:
            processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
        if forced_eos_token_id is not None:
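For orientation, the snippet below sketches the chaining pattern these `processors.append(...)` calls feed into: a `LogitsProcessorList`-style container simply applies each processor to the logits in turn. This is an illustrative toy on plain Python lists, not PaddleNLP's actual class.

# Toy sketch of the processor-chaining idea (assumption: plain floats instead
# of paddle tensors; not the real LogitsProcessorList).
class ToyProcessorList(list):
    def __call__(self, input_ids, logits):
        for processor in self:
            logits = processor(input_ids, logits)
        return logits

def halve(input_ids, logits):
    # toy "processor": scale every score down
    return [score / 2 for score in logits]

processors = ToyProcessorList([halve, halve])
print(processors(input_ids=[[1, 2, 3]], logits=[4.0, 8.0]))  # -> [1.0, 2.0]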
@@ -503,7 +507,7 @@ def _build_faster(self, kwargs):
        if kwargs["num_beam_groups"] != 1:
            # not support for group_beam_search yet in the faster version
            raise AttributeError("'num_beam_groups != 1' is not supported yet in the faster version")
-        if paddle.get_default_dtype() == "float16" and kwargs["use_fp16_decoding"] == False:
+        if paddle.get_default_dtype() == "float16" and kwargs["use_fp16_decoding"] is False:
            logger.info(
                "Since the default dtype is float16, float16 would be used " "though 'use_fp16_decoding=False'."
            )
@@ -531,6 +535,7 @@ def generate(
        decoder_start_token_id=None,
        forced_bos_token_id=None,
        forced_eos_token_id=None,
+        no_repeat_ngram_size=None,
        num_return_sequences=1,
        diversity_rate=0.0,
        use_cache=True,
@@ -729,6 +734,9 @@ def generate(
            if decoder_start_token_id is not None
            else getattr(self, "decoder_start_token_id", None)
        )
+        no_repeat_ngram_size = (
+            no_repeat_ngram_size if no_repeat_ngram_size is not None else getattr(self, "no_repeat_ngram_size", None)
+        )

        if getattr(self, "_faster_entry", None) is not False and use_faster:
            args = locals()
@@ -804,6 +812,7 @@ def generate(
            num_beam_groups=num_beam_groups,
            diversity_rate=diversity_rate,
            repetition_penalty=repetition_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
            logits_processors=model_kwargs["logits_processors"]
            if "logits_processors" in model_kwargs
            and isinstance(model_kwargs["logits_processors"], LogitsProcessorList)
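With the plumbing above in place, the new knob is exposed directly on `generate()`. The following usage sketch is an assumption-laden illustration: the GPT model name, tokenizer, and decoding settings are placeholders; only the `no_repeat_ngram_size` argument itself comes from this diff.

# Hedged usage sketch: "gpt2-en" and the GPT classes are illustrative choices;
# no_repeat_ngram_size=2 bans any 2-gram from being generated a second time.
import paddle
from paddlenlp.transformers import GPTLMHeadModel, GPTTokenizer

tokenizer = GPTTokenizer.from_pretrained("gpt2-en")
model = GPTLMHeadModel.from_pretrained("gpt2-en")
model.eval()

input_ids = paddle.to_tensor([tokenizer("The weather today is")["input_ids"]])
output_ids, scores = model.generate(
    input_ids,
    max_length=50,
    decode_strategy="beam_search",
    num_beams=4,
    no_repeat_ngram_size=2,
)
print(tokenizer.convert_ids_to_tokens(output_ids[0].tolist()))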
@@ -1337,6 +1346,64 @@ def __call__(self, input_ids, logits):
        return outputs


+def _get_ngrams(ngram_size, prev_input_ids, num_hypos):
+    generated_ngrams = [{} for _ in range(num_hypos)]
+    for idx in range(num_hypos):
+        gen_tokens = prev_input_ids[idx].tolist()
+        generated_ngram = generated_ngrams[idx]
+        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
+            prev_ngram_tuple = tuple(ngram[:-1])
+            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
+    return generated_ngrams
+
+
+def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
+    # Before decoding the next token, prevent decoding of ngrams that have already appeared
+    start_idx = cur_len + 1 - ngram_size
+    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
+    return banned_ngrams.get(ngram_idx, [])
+
+
+def _calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len):
+    """Copied from fairseq for no_repeat_ngram in beam_search"""
+    if cur_len + 1 < ngram_size:
+        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+        return [[] for _ in range(num_hypos)]
+
+    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
+
+    banned_tokens = [
+        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
+        for hypo_idx in range(num_hypos)
+    ]
+    return banned_tokens
+
+
+class NoRepeatNGramLogitsProcessor(LogitsProcessor):
+    r"""
+    [`LogitsProcessor`] that enforces no repetition of n-grams. See
+    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
+    Args:
+        ngram_size (`int`):
+            All ngrams of size `ngram_size` can only occur once.
+    """
+
+    def __init__(self, ngram_size):
+        if not isinstance(ngram_size, int) or ngram_size <= 0:
+            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+        self.ngram_size = ngram_size
+
+    def __call__(self, input_ids, scores):
+        num_batch_hypotheses = scores.shape[0]
+        cur_len = input_ids.shape[-1]
+        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
+
+        for i, banned_tokens in enumerate(banned_batch_tokens):
+            scores[i, banned_tokens] = -float("inf")
+
+        return scores
+
+
class HammingDiversityLogitsProcessor(LogitsProcessor):
    """
    This `LogitsProcessor` enforces diverse beam search. Note that this logits
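To see the banning rule added above in action, the sketch below reproduces the bookkeeping of `_get_ngrams` and `_calc_banned_ngram_tokens` for a single hypothesis, using plain Python lists instead of paddle tensors (an illustrative rewrite, not the shipped helpers).

# Illustrative sketch (assumption: plain lists, single hypothesis) of the
# n-gram bookkeeping behind NoRepeatNGramLogitsProcessor.
def toy_banned_tokens(ngram_size, generated, cur_len):
    # No n-gram can repeat before at least one full n-gram has been generated.
    if cur_len + 1 < ngram_size:
        return []
    # Map every (ngram_size - 1)-token prefix to the tokens that followed it.
    ngrams = {}
    for ngram in zip(*[generated[i:] for i in range(ngram_size)]):
        ngrams.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    # The current prefix is the last (ngram_size - 1) generated tokens.
    prefix = tuple(generated[cur_len + 1 - ngram_size : cur_len])
    return ngrams.get(prefix, [])

# The sequence already contains the 2-gram (8, 9); since it now ends in 8 again,
# token 9 is banned so that "8 9" cannot be produced a second time.
print(toy_banned_tokens(2, [7, 8, 9, 5, 8], cur_len=5))  # -> [9]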