@@ -151,7 +151,7 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
151151 * Apply the BOS token forcing to the logits.
152152 * @param {bigint[][] } input_ids The input IDs.
153153 * @param {Tensor } logits The logits.
154- * @returns {Object } The logits with BOS token forcing.
154+ * @returns {Tensor } The logits with BOS token forcing.
155155 */
156156 _call ( input_ids , logits ) {
157157 for ( let i = 0 ; i < input_ids . length ; ++ i ) {
@@ -221,7 +221,7 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
221221 * Apply the begin-of-generation token suppression to the logits.
222222 * @param {bigint[][] } input_ids The input IDs.
223223 * @param {Tensor } logits The logits.
224- * @returns {Object } The logits with BOS token forcing.
224+ * @returns {Tensor } The logits with the begin-of-generation token suppression applied.
225225 */
226226 _call ( input_ids , logits ) {
227227 for ( let i = 0 ; i < input_ids . length ; ++ i ) {
@@ -391,7 +391,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
391391 * Apply the no-repeat-ngram processor to the logits.
392392 * @param {bigint[][] } input_ids The input IDs.
393393 * @param {Tensor } logits The logits.
394- * @returns {Object } The logits with no-repeat-ngram processing.
394+ * @returns {Tensor } The logits with no-repeat-ngram processing.
395395 */
396396 _call ( input_ids , logits ) {
397397 for ( let i = 0 ; i < input_ids . length ; ++ i ) {
@@ -406,12 +406,22 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
406406}
407407
408408/**
409- * A logits processor that penalises repeated output tokens.
409+ * A logits processor that prevents the repetition of previous tokens through a penalty.
410+ * This penalty is applied at most once per token. Note that, for decoder-only models like most LLMs,
411+ * the considered tokens include the prompt.
412+ *
413+ * In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a
414+ * penalty of around 1.2 to achieve a good balance between truthful generation and lack of repetition.
415+ * To penalize and reduce repetition, use `penalty` values above 1.0, where a higher value penalizes
416+ * more strongly. To reward and encourage repetition, use `penalty` values between 0.0 and 1.0, where
417+ * a lower value rewards more strongly.
410418 */
411419export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
412420 /**
413421 * Create a RepetitionPenaltyLogitsProcessor.
414- * @param {number } penalty The penalty to apply for repeated tokens.
422+ * @param {number } penalty The parameter for repetition penalty.
423+ * - 1.0 means no penalty. Above 1.0 penalizes previously generated tokens.
424+ * - Between 0.0 and 1.0 rewards previously generated tokens.
415425 */
416426 constructor ( penalty ) {
417427 super ( ) ;
@@ -422,16 +432,12 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
422432 * Apply the repetition penalty to the logits.
423433 * @param {bigint[][] } input_ids The input IDs.
424434 * @param {Tensor } logits The logits.
425- * @returns {Object } The logits with repetition penalty processing.
435+ * @returns {Tensor } The logits with repetition penalty processing.
426436 */
427437 _call ( input_ids , logits ) {
428- // Modify the logits corresponding to each element in `input_ids`.
429- // As a consequence, the logits corresponding to tokens that appear
430- // many times in the output will be penalised more.
431-
432438 for ( let i = 0 ; i < input_ids . length ; ++ i ) {
433439 const batch_logits_data = /** @type {Float32Array } */ ( logits [ i ] . data ) ;
434- for ( const input_id of input_ids [ i ] ) {
440+ for ( const input_id of new Set ( input_ids [ i ] ) ) {
435441 const token = Number ( input_id ) ;
436442 if ( batch_logits_data [ token ] < 0 ) {
437443 batch_logits_data [ token ] *= this . penalty ;
@@ -464,7 +470,7 @@ export class MinLengthLogitsProcessor extends LogitsProcessor {
464470 * Apply logit processor.
465471 * @param {bigint[][] } input_ids The input IDs.
466472 * @param {Tensor } logits The logits.
467- * @returns {Object } The processed logits.
473+ * @returns {Tensor } The processed logits.
468474 */
469475 _call ( input_ids , logits ) {
470476 for ( let i = 0 ; i < input_ids . length ; ++ i ) {
@@ -502,7 +508,7 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
502508 * Apply logit processor.
503509 * @param {bigint[][] } input_ids The input IDs.
504510 * @param {Tensor } logits The logits.
505- * @returns {Object } The processed logits.
511+ * @returns {Tensor } The processed logits.
506512 */
507513 _call ( input_ids , logits ) {
508514 for ( let i = 0 ; i < input_ids . length ; ++ i ) {
@@ -535,7 +541,7 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
535541 * Apply logit processor.
536542 * @param {bigint[][] } input_ids The input IDs.
537543 * @param {Tensor } logits The logits.
538- * @returns {Object } The processed logits.
544+ * @returns {Tensor } The processed logits.
539545 */
540546 _call ( input_ids , logits ) {
541547 for ( let i = 0 ; i < input_ids . length ; ++ i ) {
@@ -596,7 +602,7 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {
596602 * Apply logit processor.
597603 * @param {bigint[][] } input_ids The input IDs.
598604 * @param {Tensor } logits The logits.
599- * @returns {Object } The processed logits.
605+ * @returns {Tensor } The processed logits.
600606 */
601607 _call ( input_ids , logits ) {
602608 if ( logits . dims [ 0 ] !== 2 * input_ids . length ) {
@@ -650,7 +656,7 @@ export class TemperatureLogitsWarper extends LogitsWarper {
650656 * Apply logit warper.
651657 * @param {bigint[][] } input_ids The input IDs.
652658 * @param {Tensor } logits The logits.
653- * @returns {Object } The processed logits.
659+ * @returns {Tensor } The processed logits.
654660 */
655661 _call ( input_ids , logits ) {
656662 const batch_logits_data = /** @type {Float32Array } */ ( logits . data ) ;
0 commit comments