Skip to content

Commit 9584263

Browse files
authored
Fix repetition penalty logits processor (#1062)
* Fix repetition penalty logits processor
* Fix return types of logits processors
1 parent 2c92943 commit 9584263

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

src/generation/logits_process.js

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
151151
* Apply the BOS token forcing to the logits.
152152
* @param {bigint[][]} input_ids The input IDs.
153153
* @param {Tensor} logits The logits.
154-
* @returns {Object} The logits with BOS token forcing.
154+
* @returns {Tensor} The logits with BOS token forcing.
155155
*/
156156
_call(input_ids, logits) {
157157
for (let i = 0; i < input_ids.length; ++i) {
@@ -221,7 +221,7 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
221221
* Apply the begin-token suppression to the logits.
222222
* @param {bigint[][]} input_ids The input IDs.
223223
* @param {Tensor} logits The logits.
224-
* @returns {Object} The logits with BOS token forcing.
224+
* @returns {Tensor} The logits with begin-token suppression applied.
225225
*/
226226
_call(input_ids, logits) {
227227
for (let i = 0; i < input_ids.length; ++i) {
@@ -391,7 +391,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
391391
* Apply the no-repeat-ngram processor to the logits.
392392
* @param {bigint[][]} input_ids The input IDs.
393393
* @param {Tensor} logits The logits.
394-
* @returns {Object} The logits with no-repeat-ngram processing.
394+
* @returns {Tensor} The logits with no-repeat-ngram processing.
395395
*/
396396
_call(input_ids, logits) {
397397
for (let i = 0; i < input_ids.length; ++i) {
@@ -406,12 +406,22 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
406406
}
407407

408408
/**
409-
* A logits processor that penalises repeated output tokens.
409+
* A logits processor that prevents the repetition of previous tokens through a penalty.
410+
* This penalty is applied at most once per token. Note that, for decoder-only models like most LLMs,
411+
* the considered tokens include the prompt.
412+
*
413+
* In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a
414+
* penalty of around 1.2 to achieve a good balance between truthful generation and lack of repetition.
415+
* To penalize and reduce repetition, use `penalty` values above 1.0, where a higher value penalizes
416+
* more strongly. To reward and encourage repetition, use `penalty` values between 0.0 and 1.0, where
417+
* a lower value rewards more strongly.
410418
*/
411419
export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
412420
/**
413421
* Create a RepetitionPenaltyLogitsProcessor.
414-
* @param {number} penalty The penalty to apply for repeated tokens.
422+
* @param {number} penalty The parameter for repetition penalty.
423+
* - 1.0 means no penalty. Above 1.0 penalizes previously generated tokens.
424+
* - Between 0.0 and 1.0 rewards previously generated tokens.
415425
*/
416426
constructor(penalty) {
417427
super();
@@ -422,16 +432,12 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
422432
* Apply the repetition penalty to the logits.
423433
* @param {bigint[][]} input_ids The input IDs.
424434
* @param {Tensor} logits The logits.
425-
* @returns {Object} The logits with repetition penalty processing.
435+
* @returns {Tensor} The logits with repetition penalty processing.
426436
*/
427437
_call(input_ids, logits) {
428-
// Modify the logits corresponding to each element in `input_ids`.
429-
// As a consequence, the logits corresponding to tokens that appear
430-
// many times in the output will be penalised more.
431-
432438
for (let i = 0; i < input_ids.length; ++i) {
433439
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
434-
for (const input_id of input_ids[i]) {
440+
for (const input_id of new Set(input_ids[i])) {
435441
const token = Number(input_id);
436442
if (batch_logits_data[token] < 0) {
437443
batch_logits_data[token] *= this.penalty;
@@ -464,7 +470,7 @@ export class MinLengthLogitsProcessor extends LogitsProcessor {
464470
* Apply logit processor.
465471
* @param {bigint[][]} input_ids The input IDs.
466472
* @param {Tensor} logits The logits.
467-
* @returns {Object} The processed logits.
473+
* @returns {Tensor} The processed logits.
468474
*/
469475
_call(input_ids, logits) {
470476
for (let i = 0; i < input_ids.length; ++i) {
@@ -502,7 +508,7 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
502508
* Apply logit processor.
503509
* @param {bigint[][]} input_ids The input IDs.
504510
* @param {Tensor} logits The logits.
505-
* @returns {Object} The processed logits.
511+
* @returns {Tensor} The processed logits.
506512
*/
507513
_call(input_ids, logits) {
508514
for (let i = 0; i < input_ids.length; ++i) {
@@ -535,7 +541,7 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
535541
* Apply logit processor.
536542
* @param {bigint[][]} input_ids The input IDs.
537543
* @param {Tensor} logits The logits.
538-
* @returns {Object} The processed logits.
544+
* @returns {Tensor} The processed logits.
539545
*/
540546
_call(input_ids, logits) {
541547
for (let i = 0; i < input_ids.length; ++i) {
@@ -596,7 +602,7 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {
596602
* Apply logit processor.
597603
* @param {bigint[][]} input_ids The input IDs.
598604
* @param {Tensor} logits The logits.
599-
* @returns {Object} The processed logits.
605+
* @returns {Tensor} The processed logits.
600606
*/
601607
_call(input_ids, logits) {
602608
if (logits.dims[0] !== 2 * input_ids.length) {
@@ -650,7 +656,7 @@ export class TemperatureLogitsWarper extends LogitsWarper {
650656
* Apply logit warper.
651657
* @param {bigint[][]} input_ids The input IDs.
652658
* @param {Tensor} logits The logits.
653-
* @returns {Object} The processed logits.
659+
* @returns {Tensor} The processed logits.
654660
*/
655661
_call(input_ids, logits) {
656662
const batch_logits_data = /** @type {Float32Array} */(logits.data);

0 commit comments

Comments
 (0)