Commit 75b352c

Optimize added token split (#1261)
* Optimize split on added tokens
* Add unit test for tokenizer with many added tokens
* Add dictionary splitter unit tests
* Minor improvements
1 parent c33b6a1 commit 75b352c
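
The apparent motivation: previously, every encode split the input with a single regex alternating over all added tokens, so tokenizers with thousands of added tokens (such as the Orpheus model exercised in the new test below) paid for the whole alternation repeatedly, while the trie-based DictionarySplitter introduced here only follows branches matching the current prefix. A rough micro-benchmark sketch, not from the PR — the token names and count are invented, and the import path assumes it is run from the repo root:

```js
import { DictionarySplitter } from './src/utils/data-structures.js';

// Hypothetical workload: thousands of added tokens (names invented).
const tokens = Array.from({ length: 8000 }, (_, i) => `<custom_token_${i}>`);
const text = 'hello <custom_token_42> world '.repeat(100);

// Old approach: one big alternation regex.
// Capture groups keep the matched tokens in split()'s output;
// filter(x => x) drops the undefined entries of non-matching groups.
const escape = s => s.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(tokens.map(t => `(${escape(t)})`).join('|'));
console.time('regex');
text.split(regex).filter(x => x);
console.timeEnd('regex');

// New approach: a single trie scan over the text.
const splitter = new DictionarySplitter(tokens);
console.time('trie');
splitter.split(text);
console.timeEnd('trie');
```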

File tree: 4 files changed, +144 −5 lines changed

src/tokenizers.js
src/utils/data-structures.js
tests/tokenizers.test.js
tests/utils/data_structures.test.js

src/tokenizers.js

Lines changed: 11 additions & 4 deletions
@@ -42,6 +42,7 @@ import {
     PriorityQueue,
     TokenLattice,
     CharTrie,
+    DictionarySplitter,
 } from './utils/data-structures.js';
 
 import { Template } from '@huggingface/jinja';
@@ -2597,13 +2598,20 @@ export class PreTrainedTokenizer extends Callable {
             this.decoder.end_of_word_suffix = this.model.end_of_word_suffix;
         }
 
-        this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
-            this.added_tokens.slice()
+        // Divide added tokens into those that left/right strip, and those that don't
+        const added_tokens_with_strip = this.added_tokens.filter(x => x.rstrip || x.lstrip);
+        const added_tokens_without_strip = this.added_tokens.filter(x => !x.rstrip && !x.lstrip);
+        const split_regex = added_tokens_with_strip.length > 0 ? new RegExp(
+            added_tokens_with_strip.slice()
                 // Sort by length (desc) to avoid early partial matches
                 .sort((a, b) => b.content.length - a.content.length)
                 .map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`)
                 .join('|')
         ) : null;
+        this.added_tokens_splitter = new DictionarySplitter(
+            added_tokens_without_strip.map(x => x.content),
+            split_regex,
+        );
 
         // Set mask token if present (otherwise will be undefined, which is fine)
         this.mask_token = this.getToken('mask_token');
@@ -2898,8 +2906,7 @@ export class PreTrainedTokenizer extends Callable {
         // Actual function which does encoding, for a single text
         // First, we take care of special tokens. Needed to avoid issues arising from
         // normalization and/or pretokenization (which may not preserve special tokens)
-        const sections = this.added_tokens_regex ? text.split(this.added_tokens_regex).filter(x => x) : [text];
-
+        const sections = this.added_tokens_splitter.split(text);
         const tokens = sections.map((x, section_index) => {
             const addedToken = this.added_tokens.find(t => t.content === x);
             if (addedToken !== undefined) {
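
The two halves divide the work: added tokens that strip surrounding whitespace still go through a regex, since the `\s*` context cannot be expressed in a plain dictionary, while all remaining added tokens go through the trie. Note the capturing group around `escapeRegExp(x.content)`: because `String.prototype.split` keeps captured substrings in its output, the token text survives the split while the stripped whitespace is dropped. A small standalone sketch of the interaction — the added tokens here are hypothetical stand-ins, and the import path assumes the repo root:

```js
import { DictionarySplitter } from './src/utils/data-structures.js';

// Hypothetical added tokens: one strips whitespace on both sides, one does not.
const added_tokens = [
    { content: '<|eot|>', lstrip: true, rstrip: true },
    { content: '<|user|>', lstrip: false, rstrip: false },
];

// As in the patch: stripping tokens become a capturing regex...
const escape = s => s.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const split_regex = new RegExp(
    added_tokens
        .filter(x => x.lstrip || x.rstrip)
        .map(x => `${x.lstrip ? '\\s*' : ''}(${escape(x.content)})${x.rstrip ? '\\s*' : ''}`)
        .join('|'),
);

// ...and the rest feed the trie-based splitter.
const splitter = new DictionarySplitter(
    added_tokens.filter(x => !x.lstrip && !x.rstrip).map(x => x.content),
    split_regex,
);

console.log(splitter.split('hi <|eot|> <|user|>there'));
// [ 'hi', '<|eot|>', '<|user|>', 'there' ] — whitespace around <|eot|> is stripped
```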

src/utils/data-structures.js

Lines changed: 90 additions & 0 deletions
@@ -445,3 +445,93 @@ class TokenLatticeNode {
         return n;
     }
 }
+
+/**
+ * A data structure which uses a trie to split a string into tokens based on a dictionary.
+ * It can also use a regular expression to preprocess the input text before splitting.
+ *
+ * NOTE: To ensure multi-byte characters are handled correctly, we operate at byte-level instead of character-level.
+ */
+export class DictionarySplitter {
+    /**
+     * @param {string[]} dictionary The dictionary of words to use for splitting.
+     * @param {RegExp} [splitRegex] Optional split regex for preprocessing the input text.
+     */
+    constructor(dictionary, splitRegex = null) {
+        this.trie = this._buildTrie(dictionary);
+        this.splitRegex = splitRegex;
+    }
+
+    /**
+     * Builds a trie from the given dictionary.
+     * @param {string[]} dictionary The dictionary of words to build the trie from.
+     * @returns {Object} The root node of the trie.
+     * @private
+     */
+    _buildTrie(dictionary) {
+        const trie = Object.create(null);
+        for (const word of dictionary) {
+            let node = trie;
+            for (let i = 0; i < word.length; ++i) {
+                node = (node[word[i]] ??= Object.create(null));
+            }
+            node.end = word;
+        }
+        return trie;
+    }
+
+    /**
+     * Splits the input text into tokens based on the dictionary.
+     * @param {string} text The input text to split.
+     * @returns {string[]} An array of tokens.
+     */
+    split(text) {
+        return this.splitRegex ?
+            text.split(this.splitRegex)
+                .filter(x => x)
+                .flatMap(x => this._splitSingle(x))
+            : this._splitSingle(text)
+    }
+
+    /**
+     * Helper function to split a single text string into tokens.
+     * @param {string} text The input text to split.
+     * @returns {string[]} An array of tokens.
+     * @private
+     */
+    _splitSingle(text) {
+        const result = [];
+        const n = text.length;
+        let start = 0;
+        let i = 0;
+
+        while (i < n) {
+            let node = this.trie;
+            let match = null;
+            let j = i;
+
+            while (j < n && (node = node[text[j]])) {
+                if (node.end) {
+                    // Always keep the last (i.e., longest) match.
+                    match = node.end;
+                }
+                ++j;
+            }
+
+            if (match) {
+                if (i > start) {
+                    result.push(text.slice(start, i));
+                }
+                result.push(match);
+                i += match.length;
+                start = i;
+            } else {
+                ++i;
+            }
+        }
+        if (start < n) {
+            result.push(text.slice(start));
+        }
+        return result;
+    }
+}
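
`_splitSingle` is a single left-to-right scan: at each position it walks the trie as deep as the input allows, keeps the longest dictionary entry seen on that walk, and emits any unmatched span between hits verbatim. A quick illustration of the longest-match rule (import path assumed relative to the repo root):

```js
import { DictionarySplitter } from './src/utils/data-structures.js';

// '<bos>' is a prefix of '<bos><eos>'; the walk keeps the longest hit.
const splitter = new DictionarySplitter(['<bos>', '<bos><eos>']);

console.log(splitter.split('x<bos><eos>y'));
// [ 'x', '<bos><eos>', 'y' ]
console.log(splitter.split('x<bos>y'));
// [ 'x', '<bos>', 'y' ]
```

The inner walk is bounded by the longest dictionary entry, so a scan costs at most O(n · maxEntryLength), rather than re-running a many-thousand-branch alternation at every position.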

tests/tokenizers.test.js

Lines changed: 8 additions & 0 deletions
@@ -292,6 +292,14 @@ describe("Edge cases", () => {
     },
     MAX_TEST_EXECUTION_TIME,
   );
+
+  it("many added tokens", async () => {
+    let tokenizer = await AutoTokenizer.from_pretrained("onnx-community/orpheus-3b-0.1-ft-ONNX");
+
+    let text = "hello world!";
+    let token_ids = tokenizer.encode(text);
+    compare(token_ids, [128000, 15339, 1917, 0]);
+  }, 5000); // NOTE: 5 seconds
 });
 
 describe("Extra decoding tests", () => {

tests/utils/data_structures.test.js

Lines changed: 35 additions & 1 deletion
@@ -1,4 +1,4 @@
-import { PriorityQueue } from "../../src/utils/data-structures.js";
+import { PriorityQueue, DictionarySplitter } from "../../src/utils/data-structures.js";
 
 describe("Priority queue", () => {
   const EXAMPLE_ARRAY = [2, 5, 3, 1, 4];
@@ -31,3 +31,37 @@ describe("Priority queue", () => {
     }
   });
 });
+
+describe("Dictionary splitter", () => {
+  it("should split on a defined dictionary", () => {
+    const splitter = new DictionarySplitter(
+      ["a", "b", "c", "abc"],
+      null, // no split regex
+    );
+    const text = ".a.b.cc.abcdef.";
+    const expected = [".", "a", ".", "b", ".", "c", "c", ".", "abc", "def."];
+    const result = splitter.split(text);
+    expect(result).toEqual(expected);
+  });
+  it("should split on a defined dictionary w/ split regex", () => {
+    const splitter = new DictionarySplitter(
+      ["a", "b", "c", "abc"],
+      /\s+/, // split on whitespace
+    );
+    const text = "a b c";
+    const expected = ["a", "b", "c"];
+    const result = splitter.split(text);
+    expect(result).toEqual(expected);
+  });
+
+  it("should handle multi-byte characters", () => {
+    const text = "before🤗after\ud83etest";
+    const splitter = new DictionarySplitter(
+      ["🤗" /* '\ud83e\udd17' */, "\ud83e"],
+      null, // no split regex
+    );
+    const expected = ["before", "🤗", "after", "\ud83e", "test"];
+    const result = splitter.split(text);
+    expect(result).toEqual(expected);
+  });
+});
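
On the multi-byte test: JavaScript strings are indexed by UTF-16 code units, so the trie stores '🤗' as two edges ('\ud83e' then '\udd17'), and the longest-match rule is what keeps the emoji intact even though its lone lead surrogate '\ud83e' is itself a dictionary entry. A minimal demonstration (import path assumed relative to the repo root):

```js
import { DictionarySplitter } from './src/utils/data-structures.js';

const splitter = new DictionarySplitter(['🤗' /* '\ud83e\udd17' */, '\ud83e']);

// The walk first reaches '\ud83e' (a match), then continues to '\udd17' and
// upgrades to the longer match '🤗', so the surrogate pair is never split.
console.log(splitter.split('🤗\ud83e'));
// [ '🤗', '\ud83e' ]
```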
