
Commit c2c45cb

Improve support of conversational models (#658)
* Add `return_full_text` option for text-generation models
* [wip] Support chat inputs in text-generation pipeline
* Align return type with python version
* Remove conversational task (moved to text-generation)
* Fix typos
1 parent aa542cf commit c2c45cb

File tree: 5 files changed (+101 −16 lines changed)

* README.md
* docs/snippets/5_supported-tasks.snippet
* src/pipelines.js
* src/tokenizers.js
* tests/generation.test.js


README.md

Lines changed: 0 additions & 1 deletion
@@ -198,7 +198,6 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 
 | Task | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
-| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
 | [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
 | [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
 | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |

docs/snippets/5_supported-tasks.snippet

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
 
 | Task | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
-| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
 | [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
 | [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
 | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |

src/pipelines.js

Lines changed: 60 additions & 8 deletions
@@ -841,18 +841,24 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC
     }
 }
 
+function isChat(x) {
+    return Array.isArray(x) && x.every(x => 'role' in x && 'content' in x);
+}
 
 /**
+ * @typedef {import('./tokenizers.js').Message[]} Chat
+ *
  * @typedef {Object} TextGenerationSingle
- * @property {string} generated_text The generated text.
+ * @property {string|Chat} generated_text The generated text.
  * @typedef {TextGenerationSingle[]} TextGenerationOutput
  *
  * @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines.
  * @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences.
+ * @property {boolean} [return_full_text=true] If set to `false` only added text is returned, otherwise the full text is returned.
  * @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig
  *
  * @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs.
- * @param {string|string[]} texts One or several prompts (or one list of prompts) to complete.
+ * @param {string|string[]|Chat|Chat[]} texts One or several prompts (or one list of prompts) to complete.
  * @param {TextGenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
  * @returns {Promise<TextGenerationOutput|TextGenerationOutput[]>} An array or object containing the generated texts.
 *
@@ -921,17 +927,46 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
 
     /** @type {TextGenerationPipelineCallback} */
     async _call(texts, generate_kwargs = {}) {
+        let isBatched = false;
+        let isChatInput = false;
+
+        // Normalize inputs
+        /** @type {string[]} */
+        let inputs;
+        if (typeof texts === 'string') {
+            inputs = texts = [texts];
+        } else if (Array.isArray(texts) && texts.every(x => typeof x === 'string')) {
+            isBatched = true;
+            inputs = /** @type {string[]} */(texts);
+        } else {
+            if (isChat(texts)) {
+                texts = [/** @type {Chat} */(texts)];
+            } else if (Array.isArray(texts) && texts.every(isChat)) {
+                isBatched = true;
+            } else {
+                throw new Error('Input must be a string, an array of strings, a Chat, or an array of Chats');
+            }
+            isChatInput = true;
 
-        const isBatched = Array.isArray(texts);
-        if (!isBatched) {
-            texts = [/** @type {string}*/ (texts)];
+            // If the input is a chat, we need to apply the chat template
+            inputs = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map(
+                x => this.tokenizer.apply_chat_template(x, {
+                    tokenize: false,
+                    add_generation_prompt: true,
+                })
+            ));
         }
 
         // By default, do not add special tokens
         const add_special_tokens = generate_kwargs.add_special_tokens ?? false;
 
+        // By default, return full text
+        const return_full_text = isChatInput
+            ? false
+            : generate_kwargs.return_full_text ?? true;
+
         this.tokenizer.padding_side = 'left';
-        const { input_ids, attention_mask } = this.tokenizer(texts, {
+        const { input_ids, attention_mask } = this.tokenizer(inputs, {
             add_special_tokens,
             padding: true,
             truncation: true,
@@ -941,17 +976,34 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
             inputs_attention_mask: attention_mask
         });
 
-        const decoded = this.tokenizer.batch_decode(outputTokenIds, {
+        let decoded = this.tokenizer.batch_decode(outputTokenIds, {
             skip_special_tokens: true,
         });
 
+
+        let promptLengths;
+        if (!return_full_text && input_ids.dims.at(-1) > 0) {
+            promptLengths = this.tokenizer.batch_decode(input_ids, {
+                skip_special_tokens: true,
+            }).map(x => x.length);
+        }
+
         /** @type {TextGenerationOutput[]} */
         const toReturn = Array.from({ length: texts.length }, _ => []);
         for (let i = 0; i < decoded.length; ++i) {
             const textIndex = Math.floor(i / outputTokenIds.length * texts.length);
 
+            if (promptLengths) {
+                // Trim the decoded text to only include the generated part
+                decoded[i] = decoded[i].slice(promptLengths[textIndex]);
+            }
             toReturn[textIndex].push({
-                generated_text: decoded[i]
+                generated_text: isChatInput
+                    ? [
+                        ...((/** @type {Chat[]} */(texts)[textIndex])),
+                        { role: 'assistant', content: decoded[i] },
+                    ]
+                    : decoded[i]
            });
        }
        return (!isBatched && toReturn.length === 1) ? toReturn[0] : toReturn;
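Taken together, these changes let the pipeline accept chat-formatted input directly: the messages are rendered through the tokenizer's chat template, and the reply comes back as an `assistant` message appended to the input. A minimal usage sketch (the model name is an illustrative assumption, not prescribed by this commit; any text-generation model that ships a chat template should behave the same):

```js
import { pipeline } from '@xenova/transformers';

// Illustrative model choice, not from this commit.
const generator = await pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');

const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Tell me a joke.' },
];

// Chat input is detected by isChat(), rendered via apply_chat_template(),
// and return_full_text defaults to false for chats.
const output = await generator(messages, { max_new_tokens: 64 });
// output[0].generated_text is a Chat:
// [...messages, { role: 'assistant', content: '<generated reply>' }]
```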

src/tokenizers.js

Lines changed: 6 additions & 6 deletions
@@ -2429,6 +2429,12 @@ function truncateHelper(item, length) {
 }
 
 
+/**
+ * @typedef {Object} Message
+ * @property {string} role The role of the message (e.g., "user" or "assistant" or "system").
+ * @property {string} content The content of the message.
+ */
+
 export class PreTrainedTokenizer extends Callable {
     return_token_type_ids = false;
 
@@ -2959,12 +2965,6 @@ export class PreTrainedTokenizer extends Callable {
         return this._default_chat_template;
     }
 
-    /**
-     * @typedef {Object} Message
-     * @property {string} role The role of the message (e.g., "user" or "assistant" or "system").
-     * @property {string} content The content of the message.
-     */
-
    /**
    * Converts a list of message objects with `"role"` and `"content"` keys to a list of token
    * ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
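Moving the `Message` typedef to module scope lets `pipelines.js` import it for its `Chat` type. For reference, a short sketch of the shape it documents and the same `apply_chat_template` call the text-generation pipeline now makes internally (the tokenizer name is an illustrative assumption):

```js
import { AutoTokenizer } from '@xenova/transformers';

// Illustrative tokenizer; any model with a chat_template works.
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/Qwen1.5-0.5B-Chat');

/** @type {import('./tokenizers.js').Message[]} */
const chat = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello!' },
];

// Render the chat to a prompt string, appending the assistant generation prompt.
const prompt = tokenizer.apply_chat_template(chat, {
    tokenize: false,
    add_generation_prompt: true,
});
```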

tests/generation.test.js

Lines changed: 35 additions & 0 deletions
@@ -11,6 +11,8 @@ describe('Generation parameters', () => {
     const models = [
         'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder
         'MBZUAI/LaMini-GPT-124M', // decoder-only
+
+        'Xenova/llama2.c-stories15M', // decoder-only
     ];
 
     // encoder-decoder model
@@ -135,4 +137,37 @@ describe('Generation parameters', () => {
 
     }, MAX_TEST_EXECUTION_TIME);
 
+    // decoder-only model
+    it(models[2], async () => {
+        const MAX_NEW_TOKENS = 1;
+
+        const text = [
+            'Once upon a time,',
+            'Lily',
+            'Suddenly,',
+        ];
+
+        const generator = await pipeline('text-generation', m(models[2]));
+
+        { // return_full_text=false
+            const output = await generator(text, {
+                return_full_text: false,
+                max_new_tokens: MAX_NEW_TOKENS,
+                num_beams: 2,
+                num_return_sequences: 2,
+            });
+            const lengths = output.flatMap(
+                x => x.flatMap(
+                    y => generator.tokenizer.encode(y.generated_text.trim(), null, {
+                        add_special_tokens: false,
+                    }).length
+                )
+            ).every(x => x === MAX_NEW_TOKENS);
+
+            expect(lengths).toBe(true);
+        }
+        await generator.dispose();
+
+    }, MAX_TEST_EXECUTION_TIME);
+
 });
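The new test pins `return_full_text: false` to exactly `MAX_NEW_TOKENS` generated tokens per returned sequence. Outside the test harness, the option behaves as in this sketch (the prompt and the shown output are illustrative, not from this commit):

```js
import { pipeline } from '@xenova/transformers';

const generator = await pipeline('text-generation', 'Xenova/llama2.c-stories15M');

const output = await generator('Once upon a time,', {
    max_new_tokens: 10,
    return_full_text: false, // strip the prompt; return only newly generated text
});
// output[0].generated_text contains just the continuation,
// e.g. ' there was a little girl named Lily.' (illustrative)
```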
