diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js index f82634f75..1e5030dda 100644 --- a/src/generation/logits_process.js +++ b/src/generation/logits_process.js @@ -548,13 +548,15 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor { const batch_logits_data = /** @type {Float32Array} */(logits[i].data); const ids = input_ids[i]; for (const bad_word_ids of this.bad_words_ids) { + // There aren't enough tokens to match the banned sequence + if (ids.length < bad_word_ids.length - 1) continue; + // Whether to modify the logits of the last token in the bad word id sequence let mark = true; // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last), // then we set the logits of the last bad word id to -Infinity. - for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) { - + for (let j = 1; j <= bad_word_ids.length - 1; ++j) { // NOTE: We use != instead of !== to compare bigint and number // @ts-ignore if (bad_word_ids.at(-j - 1) != ids.at(-j)) { diff --git a/tests/utils/logits_process.test.js b/tests/utils/logits_process.test.js index 5da188ed4..38ca613d0 100644 --- a/tests/utils/logits_process.test.js +++ b/tests/utils/logits_process.test.js @@ -79,6 +79,31 @@ describe("Logits Processors", () => { }, MAX_TEST_EXECUTION_TIME, ); + + it( + "different lengths", + async () => { + const text_input = "this is a test"; + + const generated_text_target = "кт México constructed lake user"; + const text_target = [{ generated_text: text_input + generated_text_target }]; + + const output = await pipe(text_input, { + max_new_tokens: 5, + bad_words_ids: [ + // default: [445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 16426n] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3931], // should never trigger (longer than input sequence) + + // block #1: [445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 16426n] + [3931, 14756, 7811], + + // result: [445n, 338n, 263n, 1243n, 3931n, 14756n, 13319n, 19437n, 1404n] + ], + }); + compare(output, text_target); + }, + MAX_TEST_EXECUTION_TIME, + ); }); afterAll(async () => {