15 changes: 11 additions & 4 deletions src/tokenizers.js
@@ -42,6 +42,7 @@ import {
     PriorityQueue,
     TokenLattice,
     CharTrie,
+    DictionarySplitter,
 } from './utils/data-structures.js';
 
 import { Template } from '@huggingface/jinja';
@@ -2597,13 +2598,20 @@ export class PreTrainedTokenizer extends Callable {
             this.decoder.end_of_word_suffix = this.model.end_of_word_suffix;
         }
 
-        this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
-            this.added_tokens.slice()
+        // Divide added tokens into those that left/right strip, and those that don't
+        const added_tokens_with_strip = this.added_tokens.filter(x => x.rstrip || x.lstrip);
+        const added_tokens_without_strip = this.added_tokens.filter(x => !x.rstrip && !x.lstrip);
+        const split_regex = added_tokens_with_strip.length > 0 ? new RegExp(
+            added_tokens_with_strip.slice()
                 // Sort by length (desc) to avoid early partial matches
                 .sort((a, b) => b.content.length - a.content.length)
                 .map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`)
                 .join('|')
         ) : null;
+        this.added_tokens_splitter = new DictionarySplitter(
+            added_tokens_without_strip.map(x => x.content),
+            split_regex,
+        );
 
         // Set mask token if present (otherwise will be undefined, which is fine)
         this.mask_token = this.getToken('mask_token');
@@ -2898,8 +2906,7 @@ export class PreTrainedTokenizer extends Callable {
         // Actual function which does encoding, for a single text
         // First, we take care of special tokens. Needed to avoid issues arising from
         // normalization and/or pretokenization (which may not preserve special tokens)
-        const sections = this.added_tokens_regex ? text.split(this.added_tokens_regex).filter(x => x) : [text];
-
+        const sections = this.added_tokens_splitter.split(text);
         const tokens = sections.map((x, section_index) => {
             const addedToken = this.added_tokens.find(t => t.content === x);
             if (addedToken !== undefined) {
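For context, a minimal sketch of how the new splitting path behaves (the token strings below are hypothetical, not taken from any real tokenizer config): added tokens that left/right strip are matched by the `split_regex` preprocessing step, while the rest are matched by the trie-based `DictionarySplitter`.

```js
import { DictionarySplitter } from './utils/data-structures.js';

// Hypothetical added tokens: '<mask>' does not strip whitespace (goes in the
// dictionary); '<|special|>' strips on both sides (goes in the split regex).
const splitter = new DictionarySplitter(
    ['<mask>'],
    /\s*(<\|special\|>)\s*/,
);

// The regex splits (and trims) first; each remaining section is then scanned
// greedily with the trie.
console.log(splitter.split('foo<mask>bar <|special|> baz'));
// → ['foo', '<mask>', 'bar', '<|special|>', 'baz']
```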
90 changes: 90 additions & 0 deletions src/utils/data-structures.js
@@ -445,3 +445,93 @@ class TokenLatticeNode {
         return n;
     }
 }
+
+/**
+ * A data structure which uses a trie to split a string into tokens based on a dictionary.
+ * It can also use a regular expression to preprocess the input text before splitting.
+ *
+ * NOTE: To ensure multi-byte characters are handled correctly, we operate at the UTF-16 code-unit level (the unit of JavaScript string indexing) rather than at the character (code-point) level.
+ */
+export class DictionarySplitter {
+    /**
+     * @param {string[]} dictionary The dictionary of words to use for splitting.
+     * @param {RegExp} [splitRegex] Optional split regex for preprocessing the input text.
+     */
+    constructor(dictionary, splitRegex = null) {
+        this.trie = this._buildTrie(dictionary);
+        this.splitRegex = splitRegex;
+    }
+
+    /**
+     * Builds a trie from the given dictionary.
+     * @param {string[]} dictionary The dictionary of words to build the trie from.
+     * @returns {Object} The root node of the trie.
+     * @private
+     */
+    _buildTrie(dictionary) {
+        const trie = Object.create(null);
+        for (const word of dictionary) {
+            let node = trie;
+            for (let i = 0; i < word.length; ++i) {
+                node = (node[word[i]] ??= Object.create(null));
+            }
+            node.end = word;
+        }
+        return trie;
+    }
+
+    /**
+     * Splits the input text into tokens based on the dictionary.
+     * @param {string} text The input text to split.
+     * @returns {string[]} An array of tokens.
+     */
+    split(text) {
+        return this.splitRegex ?
+            text.split(this.splitRegex)
+                .filter(x => x)
+                .flatMap(x => this._splitSingle(x))
+            : this._splitSingle(text);
+    }
+
+    /**
+     * Helper function to split a single text string into tokens.
+     * @param {string} text The input text to split.
+     * @returns {string[]} An array of tokens.
+     * @private
+     */
+    _splitSingle(text) {
+        const result = [];
+        const n = text.length;
+        let start = 0;
+        let i = 0;
+
+        while (i < n) {
+            let node = this.trie;
+            let match = null;
+            let j = i;
+
+            while (j < n && (node = node[text[j]])) {
+                if (node.end) {
+                    // Always keep the last (i.e., longest) match.
+                    match = node.end;
+                }
+                ++j;
+            }
+
+            if (match) {
+                if (i > start) {
+                    result.push(text.slice(start, i));
+                }
+                result.push(match);
+                i += match.length;
+                start = i;
+            } else {
+                ++i;
+            }
+        }
+        if (start < n) {
+            result.push(text.slice(start));
+        }
+        return result;
+    }
+}
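To illustrate the greedy longest-match behavior of `_splitSingle`, a short usage sketch (inputs are illustrative):

```js
import { DictionarySplitter } from './src/utils/data-structures.js';

const splitter = new DictionarySplitter(['a', 'abc']);

// The inner loop walks the trie as far as the text allows and keeps the last
// (longest) dictionary hit, so 'abc' wins over its prefix 'a'.
console.log(splitter.split('xabcx')); // → ['x', 'abc', 'x']

// Spans with no dictionary hit are passed through unchanged.
console.log(splitter.split('zzz')); // → ['zzz']
```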
8 changes: 8 additions & 0 deletions tests/tokenizers.test.js
@@ -292,6 +292,14 @@ describe("Edge cases", () => {
     },
     MAX_TEST_EXECUTION_TIME,
   );
+
+  it("many added tokens", async () => {
+    let tokenizer = await AutoTokenizer.from_pretrained("onnx-community/orpheus-3b-0.1-ft-ONNX");
+
+    let text = "hello world!";
+    let token_ids = tokenizer.encode(text);
+    compare(token_ids, [128000, 15339, 1917, 0]);
+  }, 5000); // NOTE: 5 seconds
 });
 
 describe("Extra decoding tests", () => {
36 changes: 35 additions & 1 deletion tests/utils/data_structures.test.js
@@ -1,4 +1,4 @@
-import { PriorityQueue } from "../../src/utils/data-structures.js";
+import { PriorityQueue, DictionarySplitter } from "../../src/utils/data-structures.js";
 
 describe("Priority queue", () => {
   const EXAMPLE_ARRAY = [2, 5, 3, 1, 4];
@@ -31,3 +31,37 @@ describe("Priority queue", () => {
     }
   });
 });
+
+describe("Dictionary splitter", () => {
+  it("should split on a defined dictionary", () => {
+    const splitter = new DictionarySplitter(
+      ["a", "b", "c", "abc"],
+      null, // no split regex
+    );
+    const text = ".a.b.cc.abcdef.";
+    const expected = [".", "a", ".", "b", ".", "c", "c", ".", "abc", "def."];
+    const result = splitter.split(text);
+    expect(result).toEqual(expected);
+  });
+  it("should split on a defined dictionary w/ split regex", () => {
+    const splitter = new DictionarySplitter(
+      ["a", "b", "c", "abc"],
+      /\s+/, // split on whitespace
+    );
+    const text = "a b c";
+    const expected = ["a", "b", "c"];
+    const result = splitter.split(text);
+    expect(result).toEqual(expected);
+  });
+
+  it("should handle multi-byte characters", () => {
+    const text = "before🤗after\ud83etest";
+    const splitter = new DictionarySplitter(
+      ["🤗" /* '\ud83e\udd17' */, "\ud83e"],
+      null, // no split regex
+    );
+    const expected = ["before", "🤗", "after", "\ud83e", "test"];
+    const result = splitter.split(text);
+    expect(result).toEqual(expected);
+  });
+});
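As a footnote to the multi-byte test above, a short sketch of why code-unit-level trie traversal handles surrogate pairs (values shown follow standard JavaScript string semantics):

```js
import { DictionarySplitter } from '../../src/utils/data-structures.js';

// '🤗' (U+1F917) is one code point but two UTF-16 code units: '\ud83e\udd17'.
console.log('🤗'.length);          // 2
console.log('🤗'[0] === '\ud83e'); // true

// Because the trie stores code units, '\ud83e' and '🤗' share a prefix path,
// and the longest-match rule selects '🤗' whenever the low surrogate follows.
const splitter = new DictionarySplitter(['🤗', '\ud83e']);
console.log(splitter.split('before🤗after\ud83etest'));
// → ['before', '🤗', 'after', '\ud83e', 'test']
```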