diff --git a/src/tokenizers.js b/src/tokenizers.js
index 83e33cc52..801abfda2 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -42,6 +42,7 @@ import {
     PriorityQueue,
     TokenLattice,
     CharTrie,
+    DictionarySplitter,
 } from './utils/data-structures.js';
 
 import { Template } from '@huggingface/jinja';
@@ -2597,13 +2598,20 @@ export class PreTrainedTokenizer extends Callable {
             this.decoder.end_of_word_suffix = this.model.end_of_word_suffix;
         }
 
-        this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
-            this.added_tokens.slice()
+        // Divide added tokens into those that left/right strip, and those that don't
+        const added_tokens_with_strip = this.added_tokens.filter(x => x.rstrip || x.lstrip);
+        const added_tokens_without_strip = this.added_tokens.filter(x => !x.rstrip && !x.lstrip);
+        const split_regex = added_tokens_with_strip.length > 0 ? new RegExp(
+            added_tokens_with_strip.slice()
                 // Sort by length (desc) to avoid early partial matches
                 .sort((a, b) => b.content.length - a.content.length)
                 .map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`)
                 .join('|')
         ) : null;
+        this.added_tokens_splitter = new DictionarySplitter(
+            added_tokens_without_strip.map(x => x.content),
+            split_regex,
+        );
 
         // Set mask token if present (otherwise will be undefined, which is fine)
         this.mask_token = this.getToken('mask_token');
@@ -2898,8 +2906,7 @@ export class PreTrainedTokenizer extends Callable {
         // Actual function which does encoding, for a single text
         // First, we take care of special tokens. Needed to avoid issues arising from
        // normalization and/or pretokenization (which may not preserve special tokens)
-        const sections = this.added_tokens_regex ? text.split(this.added_tokens_regex).filter(x => x) : [text];
-
+        const sections = this.added_tokens_splitter.split(text);
         const tokens = sections.map((x, section_index) => {
             const addedToken = this.added_tokens.find(t => t.content === x);
             if (addedToken !== undefined) {
diff --git a/src/utils/data-structures.js b/src/utils/data-structures.js
index 2340d12c0..1d8153bc0 100644
--- a/src/utils/data-structures.js
+++ b/src/utils/data-structures.js
@@ -445,3 +445,93 @@ class TokenLatticeNode {
         return n;
     }
 }
+
+/**
+ * A data structure which uses a trie to split a string into tokens based on a dictionary.
+ * It can also use a regular expression to preprocess the input text before splitting.
+ *
+ * NOTE: To ensure multi-byte characters (e.g., lone surrogates) are handled correctly, we operate on UTF-16 code units rather than Unicode code points.
+ */
+export class DictionarySplitter {
+    /**
+     * @param {string[]} dictionary The dictionary of words to use for splitting.
+     * @param {RegExp} [splitRegex] Optional split regex for preprocessing the input text.
+     */
+    constructor(dictionary, splitRegex = null) {
+        this.trie = this._buildTrie(dictionary);
+        this.splitRegex = splitRegex;
+    }
+
+    /**
+     * Builds a trie from the given dictionary.
+     * @param {string[]} dictionary The dictionary of words to build the trie from.
+     * @returns {Object} The root node of the trie.
+     * @private
+     */
+    _buildTrie(dictionary) {
+        const trie = Object.create(null);
+        for (const word of dictionary) {
+            let node = trie;
+            for (let i = 0; i < word.length; ++i) {
+                node = (node[word[i]] ??= Object.create(null));
+            }
+            node.end = word;
+        }
+        return trie;
+    }
+
+    /**
+     * Splits the input text into tokens based on the dictionary.
+     * @param {string} text The input text to split.
+     * @returns {string[]} An array of tokens.
+     */
+    split(text) {
+        return this.splitRegex ?
+            text.split(this.splitRegex) // a capturing split regex keeps the matched separators
+                .filter(x => x)
+                .flatMap(x => this._splitSingle(x))
+            : this._splitSingle(text);
+    }
+
+    /**
+     * Helper function to split a single text string into tokens.
+     * @param {string} text The input text to split.
+     * @returns {string[]} An array of tokens.
+     * @private
+     */
+    _splitSingle(text) {
+        const result = [];
+        const n = text.length;
+        let start = 0;
+        let i = 0;
+
+        while (i < n) {
+            let node = this.trie;
+            let match = null;
+            let j = i;
+
+            while (j < n && (node = node[text[j]])) {
+                if (node.end) {
+                    // Always keep the last (i.e., longest) match.
+                    match = node.end;
+                }
+                ++j;
+            }
+
+            if (match) {
+                if (i > start) {
+                    result.push(text.slice(start, i));
+                }
+                result.push(match);
+                i += match.length;
+                start = i;
+            } else {
+                ++i;
+            }
+        }
+        if (start < n) {
+            result.push(text.slice(start));
+        }
+        return result;
+    }
+}
diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js
index 943ce5898..2742513ee 100644
--- a/tests/tokenizers.test.js
+++ b/tests/tokenizers.test.js
@@ -292,6 +292,14 @@ describe("Edge cases", () => {
     },
     MAX_TEST_EXECUTION_TIME,
   );
+
+  it("many added tokens", async () => {
+    let tokenizer = await AutoTokenizer.from_pretrained("onnx-community/orpheus-3b-0.1-ft-ONNX");
+
+    let text = "hello world!";
+    let token_ids = tokenizer.encode(text);
+    compare(token_ids, [128000, 15339, 1917, 0]);
+  }, 5000); // NOTE: 5 seconds
 });
 
 describe("Extra decoding tests", () => {
diff --git a/tests/utils/data_structures.test.js b/tests/utils/data_structures.test.js
index 033a91d00..dfb2db1d5 100644
--- a/tests/utils/data_structures.test.js
+++ b/tests/utils/data_structures.test.js
@@ -1,4 +1,4 @@
-import { PriorityQueue } from "../../src/utils/data-structures.js";
+import { PriorityQueue, DictionarySplitter } from "../../src/utils/data-structures.js";
 
 describe("Priority queue", () => {
   const EXAMPLE_ARRAY = [2, 5, 3, 1, 4];
@@ -31,3 +31,37 @@
     }
   });
 });
+
+describe("Dictionary splitter", () => {
+  it("should split on a defined dictionary", () => {
+    const splitter = new DictionarySplitter(
+      ["a", "b", "c", "abc"],
+      null, // no split regex
+    );
+    const text = ".a.b.cc.abcdef.";
+    const expected = [".", "a", ".", "b", ".", "c", "c", ".", "abc", "def."];
+    const result = splitter.split(text);
+    expect(result).toEqual(expected);
+  });
+  it("should split on a defined dictionary w/ split regex", () => {
+    const splitter = new DictionarySplitter(
+      ["a", "b", "c", "abc"],
+      /\s+/, // split on whitespace
+    );
+    const text = "a b c";
+    const expected = ["a", "b", "c"];
+    const result = splitter.split(text);
+    expect(result).toEqual(expected);
+  });
+
+  it("should handle multi-byte characters", () => {
+    const text = "before🤗after\ud83etest";
+    const splitter = new DictionarySplitter(
+      ["🤗" /* '\ud83e\udd17' */, "\ud83e"],
+      null, // no split regex
+    );
+    const expected = ["before", "🤗", "after", "\ud83e", "test"];
+    const result = splitter.split(text);
+    expect(result).toEqual(expected);
+  });
+});
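
For reviewers, a minimal usage sketch of the new split path (not part of the patch). The token strings `<|eot|>` and `<|pad|>` are hypothetical stand-ins: tokens without strip behavior go into the trie dictionary, while tokens with lstrip/rstrip stay in a capturing split regex, mirroring how `split_regex` is built in `tokenizers.js` above.

```js
import { DictionarySplitter } from "./src/utils/data-structures.js";

// Hypothetical added tokens: "<|eot|>" has no strip behavior (trie dictionary);
// "<|pad|>" strips surrounding whitespace (capturing regex, as constructed in
// PreTrainedTokenizer's constructor).
const splitter = new DictionarySplitter(["<|eot|>"], /\s*(<\|pad\|>)\s*/);

console.log(splitter.split("hello<|eot|>world <|pad|> bye"));
// [ 'hello', '<|eot|>', 'world', '<|pad|>', 'bye' ]
// The capture group keeps "<|pad|>" in the output while its surrounding
// whitespace is consumed, preserving the old added_tokens_regex behavior.
```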
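Note that `_splitSingle` performs a greedy longest-match scan: at each position it walks the trie as far as the text allows and keeps the longest dictionary entry found there, rather than searching for a globally optimal segmentation. A small sketch of that behavior, with dictionary entries chosen purely for illustration:

```js
import { DictionarySplitter } from "./src/utils/data-structures.js";

// Overlapping entries: at a given start position, "inn" beats "in".
const splitter = new DictionarySplitter(["in", "inn", "n"], null);

console.log(splitter.split("winning"));
// [ 'w', 'inn', 'in', 'g' ]
// "inn" is matched at index 1 (longest match wins), then "in" at index 4;
// unmatched spans ("w", "g") pass through unchanged.
```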