Skip to content

Commit 9f491b9

Browse files
committed
Add special tokens in text-generation pipeline if tokenizer requires
1 parent 2c32e1d commit 9f491b9

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

src/pipelines.js

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,11 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
996996
let isBatched = false;
997997
let isChatInput = false;
998998

999+
// By default, do not add special tokens, unless the tokenizer specifies otherwise
1000+
let add_special_tokens = generate_kwargs.add_special_tokens
1001+
?? (this.tokenizer.add_bos_token || this.tokenizer.add_eos_token)
1002+
?? false;
1003+
9991004
// Normalize inputs
10001005
/** @type {string[]} */
10011006
let inputs;
@@ -1021,11 +1026,9 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
10211026
add_generation_prompt: true,
10221027
})
10231028
));
1029+
add_special_tokens = false; // Chat template handles this already
10241030
}
10251031

1026-
// By default, do not add special tokens
1027-
const add_special_tokens = generate_kwargs.add_special_tokens ?? false;
1028-
10291032
// By default, return full text
10301033
const return_full_text = isChatInput
10311034
? false

src/tokenizers.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2659,6 +2659,9 @@ export class PreTrainedTokenizer extends Callable {
26592659
this.padding_side = tokenizerConfig.padding_side;
26602660
}
26612661

2662+
this.add_bos_token = tokenizerConfig.add_bos_token;
2663+
this.add_eos_token = tokenizerConfig.add_eos_token;
2664+
26622665
this.legacy = false;
26632666

26642667
this.chat_template = tokenizerConfig.chat_template ?? null;

tests/pipelines/test_pipelines_text_generation.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export default () => {
2020

2121
describe("batch_size=1", () => {
2222
const text_input = "hello";
23-
const generated_text_target = "erdingsAndroid Load";
23+
const generated_text_target = "erdingsdelete mely";
2424
const text_target = [{ generated_text: text_input + generated_text_target }];
2525
const new_text_target = [{ generated_text: generated_text_target }];
2626

0 commit comments

Comments
 (0)