@@ -151,7 +151,7 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
      * Apply the BOS token forcing to the logits.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The logits with BOS token forcing.
+     * @returns {Tensor} The logits with BOS token forcing.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -221,7 +221,7 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
      * Apply the BOS token forcing to the logits.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The logits with BOS token forcing.
+     * @returns {Tensor} The logits with BOS token forcing.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -391,7 +391,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
      * Apply the no-repeat-ngram processor to the logits.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The logits with no-repeat-ngram processing.
+     * @returns {Tensor} The logits with no-repeat-ngram processing.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -406,12 +406,22 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
 }
 
 /**
- * A logits processor that penalises repeated output tokens.
+ * A logits processor that prevents the repetition of previous tokens through a penalty.
+ * This penalty is applied at most once per token. Note that, for decoder-only models like most LLMs,
+ * the considered tokens include the prompt.
+ *
+ * In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a
+ * penalty of around 1.2 to achieve a good balance between truthful generation and lack of repetition.
+ * To penalize and reduce repetition, use `penalty` values above 1.0, where a higher value penalizes
+ * more strongly. To reward and encourage repetition, use `penalty` values between 0.0 and 1.0, where
+ * a lower value rewards more strongly.
  */
 export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
     /**
      * Create a RepetitionPenaltyLogitsProcessor.
-     * @param {number} penalty The penalty to apply for repeated tokens.
+     * @param {number} penalty The parameter for repetition penalty.
+     * - 1.0 means no penalty. Above 1.0 penalizes previously generated tokens.
+     * - Between 0.0 and 1.0 rewards previously generated tokens.
      */
     constructor(penalty) {
         super();
@@ -422,16 +432,12 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
      * Apply the repetition penalty to the logits.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The logits with repetition penalty processing.
+     * @returns {Tensor} The logits with repetition penalty processing.
      */
     _call(input_ids, logits) {
-        // Modify the logits corresponding to each element in `input_ids`.
-        // As a consequence, the logits corresponding to tokens that appear
-        // many times in the output will be penalised more.
-
         for (let i = 0; i < input_ids.length; ++i) {
             const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
-            for (const input_id of input_ids[i]) {
+            for (const input_id of new Set(input_ids[i])) {
                 const token = Number(input_id);
                 if (batch_logits_data[token] < 0) {
                     batch_logits_data[token] *= this.penalty;
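
The functional change in this hunk is the `new Set(input_ids[i])` wrapper: each distinct token is now penalized at most once, however often it occurs in the history. A minimal standalone sketch of the arithmetic, assuming a plain `Float32Array` in place of a `Tensor` row; `applyRepetitionPenalty` is a hypothetical helper (not part of the library API), and the positive-logit branch follows the standard CTRL-style formulation that sits just below the lines shown:

```js
// Hypothetical helper mirroring RepetitionPenaltyLogitsProcessor._call
// for a single batch item. Not part of the library API.
function applyRepetitionPenalty(logitsRow, inputIds, penalty) {
    // Deduplicate first, so each token is penalized at most once.
    for (const inputId of new Set(inputIds)) {
        const token = Number(inputId);
        if (logitsRow[token] < 0) {
            // Negative logits grow in magnitude when penalty > 1.
            logitsRow[token] *= penalty;
        } else {
            // Positive logits shrink toward zero when penalty > 1.
            logitsRow[token] /= penalty;
        }
    }
    return logitsRow;
}

// Token 1 appears twice in the history but is only penalized once:
const row = new Float32Array([2.0, -1.0, 0.5]);
applyRepetitionPenalty(row, [0n, 1n, 1n], 1.2);
// row ≈ [1.667, -1.2, 0.5]; without the Set, row[1] would be -1.44.
```

This matches the updated docstring: values above 1.0 push already-seen tokens away from selection, values between 0.0 and 1.0 pull them toward it.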
@@ -464,7 +470,7 @@ export class MinLengthLogitsProcessor extends LogitsProcessor {
      * Apply logit processor.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -502,7 +508,7 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
      * Apply logit processor.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -535,7 +541,7 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
      * Apply logit processor.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -596,7 +602,7 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {
      * Apply logit processor.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         if (logits.dims[0] !== 2 * input_ids.length) {
@@ -650,7 +656,7 @@ export class TemperatureLogitsWarper extends LogitsWarper {
      * Apply logit warper.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         const batch_logits_data = /** @type {Float32Array} */(logits.data);
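For completeness, a hedged usage sketch of how this processor is typically engaged from the transformers.js pipeline API; the package name and model ID below are illustrative assumptions, not taken from this diff:

```js
import { pipeline } from '@huggingface/transformers';

// repetition_penalty > 1.0 activates RepetitionPenaltyLogitsProcessor;
// ~1.2 is the balance point suggested in the CTRL paper cited above.
const generator = await pipeline('text-generation', 'Xenova/gpt2');
const output = await generator('Once upon a time', {
    max_new_tokens: 50,
    repetition_penalty: 1.2,
});
console.log(output);
```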