Skip to content

Commit db9250b

Browse files
authored
Add sequence post processor (#771)
* Add `Sequence` PostProcessor Required by https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/ * Support `return_token_type_ids` * Add llama3 tokenizer to unit tests * Add test for allowing the user to request token type ids * Add JSDoc * Update generate_tests.py
1 parent e50f568 commit db9250b

File tree

3 files changed

+100
-6
lines changed

3 files changed

+100
-6
lines changed

src/tokenizers.js

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,8 @@ class PostProcessor extends Callable {
15921592
case 'BertProcessing':
15931593
return new BertProcessing(config);
15941594

1595+
case 'Sequence':
1596+
return new PostProcessorSequence(config);
15951597
default:
15961598
throw new Error(`Unknown PostProcessor type: ${config.type}`);
15971599
}
@@ -1738,6 +1740,50 @@ class ByteLevelPostProcessor extends PostProcessor {
17381740
}
17391741
}
17401742

1743+
1744+
/**
 * A post-processor that applies multiple post-processors in sequence.
 */
class PostProcessorSequence extends PostProcessor {

    /**
     * Creates a new instance of PostProcessorSequence.
     * @param {Object} config The configuration object.
     * @param {Object[]} config.processors The list of post-processor configs; each is
     * instantiated via `PostProcessor.fromConfig`.
     */
    constructor(config) {
        super(config);

        this.processors = config.processors.map(x => PostProcessor.fromConfig(x));
    }

    /**
     * Post process the given tokens by running every child post-processor in order,
     * feeding each one the tokens produced by the previous.
     * @param {string[]} tokens The list of tokens for the first sequence.
     * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional).
     * @param {Object} [options={}] Additional options forwarded to each child post-processor.
     * @returns {PostProcessedOutput} An object containing the post-processed tokens
     * (and `token_type_ids`, if any child processor produced them).
     */
    post_process(tokens, tokens_pair = null, options = {}) {
        let token_type_ids;
        for (const processor of this.processors) {
            if (processor instanceof ByteLevelPostProcessor) {
                // Special case: ByteLevel processes each sequence independently,
                // so run it on the pair separately instead of passing tokens_pair through.
                const output = processor.post_process(tokens);
                tokens = output.tokens;
                if (tokens_pair) {
                    const pair_output = processor.post_process(tokens_pair);
                    tokens_pair = pair_output.tokens;
                }
            } else {
                const output = processor.post_process(tokens, tokens_pair, options);
                tokens = output.tokens;
                // Last processor to emit token_type_ids wins.
                token_type_ids = output.token_type_ids;
            }
        }
        return { tokens, token_type_ids };
    }
}
1786+
17411787
/**
17421788
* The base class for token decoders.
17431789
* @extends Callable
@@ -2100,7 +2146,7 @@ class DecoderSequence extends Decoder {
21002146
/**
21012147
* Creates a new instance of DecoderSequence.
21022148
* @param {Object} config The configuration object.
2103-
* @param {Decoder[]} config.decoders The list of decoders to apply.
2149+
* @param {Object[]} config.decoders The list of decoders to apply.
21042150
*/
21052151
constructor(config) {
21062152
super(config);
@@ -2623,6 +2669,7 @@ export class PreTrainedTokenizer extends Callable {
26232669
* @param {boolean} [options.truncation=null] Whether to truncate the input sequences.
26242670
* @param {number} [options.max_length=null] Maximum length of the returned list and optionally padding length.
26252671
* @param {boolean} [options.return_tensor=true] Whether to return the results as Tensors or arrays.
2672+
* @param {boolean} [options.return_token_type_ids=null] Whether to return the token type ids.
26262673
* @returns {BatchEncoding} Object to be passed to the model.
26272674
*/
26282675
_call(
@@ -2637,6 +2684,7 @@ export class PreTrainedTokenizer extends Callable {
26372684
truncation = null,
26382685
max_length = null,
26392686
return_tensor = true, // Different to HF
2687+
return_token_type_ids = null,
26402688
} = {},
26412689
) {
26422690

@@ -2659,11 +2707,11 @@ export class PreTrainedTokenizer extends Callable {
26592707
}
26602708

26612709
encodedTokens = text.map(
2662-
(t, i) => this._encode_plus(t, text_pair[i], { add_special_tokens })
2710+
(t, i) => this._encode_plus(t, text_pair[i], { add_special_tokens, return_token_type_ids })
26632711
)
26642712

26652713
} else {
2666-
encodedTokens = text.map(x => this._encode_plus(x, null, { add_special_tokens }));
2714+
encodedTokens = text.map(x => this._encode_plus(x, null, { add_special_tokens, return_token_type_ids }));
26672715
}
26682716

26692717
} else {
@@ -2676,7 +2724,7 @@ export class PreTrainedTokenizer extends Callable {
26762724
}
26772725

26782726
// For single input, we just wrap in an array, and then unwrap later.
2679-
encodedTokens = [this._encode_plus(text, text_pair, { add_special_tokens })];
2727+
encodedTokens = [this._encode_plus(text, text_pair, { add_special_tokens, return_token_type_ids })];
26802728
}
26812729
// At this point, tokens is batched: [batch_size, tokens]
26822730
// However, array may be jagged. So, we pad to max_length
@@ -2834,11 +2882,13 @@ export class PreTrainedTokenizer extends Callable {
28342882
* @param {string|null} text_pair The optional second text to encode.
28352883
* @param {Object} options An optional object containing the following properties:
28362884
* @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
2885+
* @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids.
28372886
* @returns {EncodingSingle} An object containing the encoded text.
28382887
* @private
28392888
*/
28402889
_encode_plus(text, text_pair = null, {
28412890
add_special_tokens = true,
2891+
return_token_type_ids = null,
28422892
} = {}) {
28432893
// Function called by users to encode possibly multiple texts
28442894
const tokens = this._encode_text(text);
@@ -2854,7 +2904,7 @@ export class PreTrainedTokenizer extends Callable {
28542904
input_ids,
28552905
attention_mask: new Array(input_ids.length).fill(1),
28562906
}
2857-
if (this.return_token_type_ids && combinedTokens.token_type_ids) {
2907+
if ((return_token_type_ids ?? this.return_token_type_ids) && combinedTokens.token_type_ids) {
28582908
result.token_type_ids = combinedTokens.token_type_ids;
28592909
}
28602910
return result;
@@ -2867,13 +2917,16 @@ export class PreTrainedTokenizer extends Callable {
28672917
* @param {string|null} text_pair The optional second text to encode.
28682918
* @param {Object} options An optional object containing the following properties:
28692919
* @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
2920+
* @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids.
28702921
* @returns {number[]} An array of token IDs representing the encoded text(s).
28712922
*/
28722923
encode(text, text_pair = null, {
28732924
add_special_tokens = true,
2925+
return_token_type_ids = null,
28742926
} = {}) {
28752927
const { input_ids } = this._encode_plus(text, text_pair, {
28762928
add_special_tokens,
2929+
return_token_type_ids,
28772930
});
28782931
return input_ids;
28792932
}

tests/generate_tests.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
'Xenova/llama2-tokenizer', # Special tokens: normalized=false
2121
'Xenova/llama2-chat-tokenizer', # Special tokens: normalized=false
2222
'hf-internal-testing/llama-code-tokenizer',
23+
24+
# TODO: add back when llama tests are fixed
25+
# 'Xenova/llama3-tokenizer-new', # PostProcessor type: Sequence
2326
],
2427
'mpt': [
2528
'mosaicml/mpt-7b',
@@ -289,7 +292,7 @@ def generate_tokenizer_tests():
289292
# Load tokenizer
290293
if model_type == 'llama':
291294
# As of 17/12/2023, there are a few issues with the Llama tokenizers in transformers.
292-
# (1) Encoding with fast tokenizer adds whitespace after speical tokens:
295+
# (1) Encoding with fast tokenizer adds whitespace after special tokens:
293296
# - https://github.com/huggingface/transformers/issues/25881
294297
# - https://github.com/huggingface/transformers/issues/26318
295298
# - https://github.com/huggingface/transformers/issues/26455

tests/tokenizers.test.js

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,44 @@ describe('Token type ids', () => {
288288
compare(model_inputs, expected);
289289

290290
}, MAX_TEST_EXECUTION_TIME);
291+
292+
// Verifies that passing `return_token_type_ids: true` forces `token_type_ids`
// into the output, even for a tokenizer that would not emit them by default.
it('should add token type ids if user requests them', async () => {
    const tokenizer = await AutoTokenizer.from_pretrained('Xenova/llama3-tokenizer-new');

    { // Without text pair: a single sequence gets all-zero token type ids.
        const model_inputs = tokenizer(
            'hello',
            {
                return_tensor: false,
                return_token_type_ids: true,
            }
        );
        const expected = {
            input_ids: [128000, 15339],
            attention_mask: [1, 1],
            token_type_ids: [0, 0]
        }
        compare(model_inputs, expected);
    }

    { // With text pair: the second sequence's tokens are marked with type id 1.
        const model_inputs = tokenizer(
            'hello',
            {
                text_pair: 'world',
                return_tensor: false,
                return_token_type_ids: true,
            }
        );
        const expected = {
            input_ids: [128000, 15339, 128000, 14957],
            attention_mask: [1, 1, 1, 1],
            token_type_ids: [0, 0, 1, 1]
        }
        compare(model_inputs, expected);
    }

}, MAX_TEST_EXECUTION_TIME);
291329
});
292330

293331
describe('Edge cases', () => {

0 commit comments

Comments
 (0)