Commit 060ac83

Add M2M100 tokenizer (Closes #235) (#250)
* Add `M2M100Tokenizer`
* Allow `added_tokens` list to be empty
* Apply hot-fix for issue in HF's `M2M100Tokenizer`
* Skip M2M100 tokenizer tests for now
  (TODO: Remove when huggingface/transformers#25478 is merged)
* Fix `_build_translation_inputs` for `M2M100Tokenizer`
* Add example code in JSDoc for `TranslationPipeline`
* Update supported_models.py
1 parent cc4b857 commit 060ac83

File tree

4 files changed: +131 -31 lines changed

scripts/supported_models.py

Lines changed: 1 addition & 0 deletions
@@ -140,6 +140,7 @@
     ],
     'm2m_100': [
         'facebook/nllb-200-distilled-600M',
+        'facebook/m2m100_418M',
     ],
     # TODO:
     # 'marian': [

src/pipelines.js

Lines changed: 30 additions & 2 deletions
@@ -452,8 +452,36 @@ export class SummarizationPipeline extends Text2TextGenerationPipeline {
 }
 
 /**
- * TranslationPipeline class to translate text from one language to another using the provided model and tokenizer.
- * @extends Text2TextGenerationPipeline
+ * Translates text from one language to another.
+ *
+ * **Example:** Multilingual translation w/ `Xenova/nllb-200-distilled-600M`.
+ *
+ * See [here](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200)
+ * for the full list of languages and their corresponding codes.
+ *
+ * ```javascript
+ * let translator = await pipeline('translation', 'Xenova/nllb-200-distilled-600M');
+ * let output = await translator('जीवन एक चॉकलेट बॉक्स की तरह है।', {
+ *     src_lang: 'hin_Deva', // Hindi
+ *     tgt_lang: 'fra_Latn', // French
+ * });
+ * // [ { translation_text: 'La vie est comme une boîte à chocolat.' } ]
+ * ```
+ *
+ * **Example:** Multilingual translation w/ `Xenova/m2m100_418M`.
+ *
+ * See [here](https://huggingface.co/facebook/m2m100_418M#languages-covered)
+ * for the full list of languages and their corresponding codes.
+ *
+ * ```javascript
+ * let translator = await pipeline('translation', 'Xenova/m2m100_418M');
+ * let output = await translator('生活就像一盒巧克力。', {
+ *     src_lang: 'zh', // Chinese
+ *     tgt_lang: 'en', // English
+ * });
+ * // [ { translation_text: 'Life is like a box of chocolate.' } ]
+ * ```
+ *
  */
 export class TranslationPipeline extends Text2TextGenerationPipeline {
     _key = 'translation_text';

src/tokenizers.js

Lines changed: 92 additions & 29 deletions
@@ -1995,12 +1995,16 @@ export class PreTrainedTokenizer extends Callable {
             }
         }
 
+        // Update additional_special_tokens
+        this.special_tokens.push(...(tokenizerConfig.additional_special_tokens ?? []));
+        this.special_tokens = [...new Set(this.special_tokens)]; // Remove duplicates
+
         // Slight hack, but it prevents code duplication:
         this.decoder.added_tokens = this.added_tokens;
 
-        this.added_tokens_regex = new RegExp(
+        this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
             '(' + this.added_tokens.map(escapeRegExp).join('|') + ')'
-        );
+        ) : null;
 
         // Set mask token if present (otherwise will be undefined, which is fine)
         this.mask_token = this.getToken(tokenizerConfig, 'mask_token');
@@ -2265,8 +2269,7 @@ export class PreTrainedTokenizer extends Callable {
         // Actual function which does encoding, for a single text
         // First, we take care of special tokens. Needed to avoid issues arising from
         // normalization and/or pretokenization (which may not preserve special tokens)
-        const sections = text.split(this.added_tokens_regex).filter(x => x);
-
+        const sections = this.added_tokens_regex ? text.split(this.added_tokens_regex).filter(x => x) : [text];
        let tokens = sections.map(x => {
            if (this.added_tokens.includes(x)) {
                // Ignore added tokens
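
Why the `added_tokens_regex` guard above matters: with no added tokens, the old code built an empty-alternation regex, `/()/`, which matches the empty string at every position and shreds the input into single characters. A minimal sketch of the failure mode and the fix (plain Node.js, not part of the diff):

```javascript
// Old behaviour: an empty added_tokens list produces /()/,
// which matches the empty string between every character.
const added_tokens = [];
const old_regex = new RegExp('(' + added_tokens.join('|') + ')'); // /()/
console.log('hello'.split(old_regex).filter(x => x));
// [ 'h', 'e', 'l', 'l', 'o' ]  -- each character becomes its own section

// New behaviour: the regex is null when there are no added tokens,
// so the whole text passes through as a single section.
const added_tokens_regex = added_tokens.length > 0 ? old_regex : null;
const sections = added_tokens_regex
    ? 'hello'.split(added_tokens_regex).filter(x => x)
    : ['hello'];
console.log(sections); // [ 'hello' ]
```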
@@ -2482,6 +2485,58 @@ export class FalconTokenizer extends PreTrainedTokenizer {
 
 export class GPTNeoXTokenizer extends PreTrainedTokenizer { }
 
+
+/**
+ * Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`.
+ * @param {PreTrainedTokenizer} self The tokenizer instance.
+ * @param {string|string[]} raw_inputs The text to tokenize.
+ * @param {Object} tokenizer_options Options to be sent to the tokenizer
+ * @param {Object} generate_kwargs Generation options.
+ * @returns {Object} Object to be passed to the model.
+ * @private
+ */
+function _build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs) {
+    if (!('language_codes' in self) || !Array.isArray(self.language_codes)) {
+        throw new Error('Tokenizer must have `language_codes` attribute set and it should be an array of language ids.')
+    }
+    if (!('languageRegex' in self) || !(self.languageRegex instanceof RegExp)) {
+        throw new Error('Tokenizer must have `languageRegex` attribute set and it should be a regular expression.')
+    }
+    if (!('lang_to_token' in self) || typeof self.lang_to_token !== 'function') {
+        throw new Error('Tokenizer must have `lang_to_token` attribute set and it should be a function.')
+    }
+    const src_lang_token = generate_kwargs.src_lang;
+    const tgt_lang_token = generate_kwargs.tgt_lang;
+
+    // Check that the target language is valid:
+    if (!self.language_codes.includes(tgt_lang_token)) {
+        throw new Error(`Target language code "${tgt_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`);
+    }
+
+    // Allow `src_lang` to be optional. If not set, we'll use the tokenizer's default.
+    if (src_lang_token !== undefined) {
+        // Check that the source language is valid:
+        if (!self.language_codes.includes(src_lang_token)) {
+            throw new Error(`Source language code "${src_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`);
+        }
+
+        // In the same way as the Python library, we override the post-processor
+        // to force the source language to be first:
+        for (let item of self.post_processor.config.single) {
+            if ('SpecialToken' in item && self.languageRegex.test(item.SpecialToken.id)) {
+                item.SpecialToken.id = self.lang_to_token(src_lang_token);
+                break;
+            }
+        }
+        // TODO: Do the same for pair?
+    }
+
+    // Override the `forced_bos_token_id` to force the correct language
+    generate_kwargs.forced_bos_token_id = self.model.convert_tokens_to_ids([self.lang_to_token(tgt_lang_token)])[0];
+
+    return self._call(raw_inputs, tokenizer_options);
+}
+
 /**
  * The NllbTokenizer class is used to tokenize text for NLLB ("No Language Left Behind") models.
  *
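
To make the post-processor override in `_build_translation_inputs` concrete: the loop rewrites the first language-looking special token in the tokenizer's single-sequence template. A hypothetical sketch of its effect follows; the `single` template below is an assumption about what an NLLB-style `tokenizer.json` contains, not something taken from the diff:

```javascript
// Assumed NLLB-style single-sequence template ('eng_Latn' standing in as
// the tokenizer's default source-language token; hypothetical shape):
let single = [
    { SpecialToken: { id: 'eng_Latn', type_id: 0 } },
    { Sequence: { id: 'A', type_id: 0 } },
    { SpecialToken: { id: '</s>', type_id: 0 } },
];
const languageRegex = /^[a-z]{3}_[A-Z][a-z]{3}$/; // from NllbTokenizer below
const lang_to_token = x => x;                     // NLLB identity mapping

// Same loop as the helper: swap the first language token for src_lang.
const src_lang = 'hin_Deva';
for (let item of single) {
    if ('SpecialToken' in item && languageRegex.test(item.SpecialToken.id)) {
        item.SpecialToken.id = lang_to_token(src_lang);
        break;
    }
}
console.log(single[0].SpecialToken.id); // 'hin_Deva'
```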
@@ -2502,6 +2557,7 @@ export class NllbTokenizer extends PreTrainedTokenizer {
 
         this.languageRegex = /^[a-z]{3}_[A-Z][a-z]{3}$/;
         this.language_codes = this.special_tokens.filter(x => this.languageRegex.test(x));
+        this.lang_to_token = x => x; // Identity function
     }
 
     /**
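
The only per-model difference the shared helper needs is how a bare language code maps to its special token, which the new `lang_to_token` hook captures: `NllbTokenizer` above uses the identity mapping, while the `M2M100Tokenizer` added in the next hunk wraps codes in double underscores. A small sketch of the two conventions, with values grounded in the two constructors:

```javascript
// NllbTokenizer: codes like 'fra_Latn' are themselves the special tokens.
const nllb_lang_to_token = x => x;               // identity
// M2M100Tokenizer: special tokens wrap 2-3 letter codes in underscores.
const m2m100_lang_to_token = x => `__${x}__`;

console.log(nllb_lang_to_token('fra_Latn'));     // 'fra_Latn'
console.log(m2m100_lang_to_token('en'));         // '__en__'

// This is also why M2M100 strips the wrapper when collecting its codes:
console.log('__en__'.slice(2, -2));              // 'en'
```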
@@ -2512,34 +2568,40 @@ export class NllbTokenizer extends PreTrainedTokenizer {
      * @returns {Object} Object to be passed to the model.
      */
     _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) {
+        return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs);
+    }
+}
 
+/**
+ * The M2M100Tokenizer class is used to tokenize text for M2M100 ("Many-to-Many") models.
+ *
+ * M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many
+ * multilingual translation. It was introduced in this [paper](https://arxiv.org/abs/2010.11125)
+ * and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
+ *
+ * For a list of supported languages (along with their language codes),
+ * @see {@link https://huggingface.co/facebook/m2m100_418M#languages-covered}
+ */
+export class M2M100Tokenizer extends PreTrainedTokenizer {
+    constructor(tokenizerJSON, tokenizerConfig) {
+        super(tokenizerJSON, tokenizerConfig);
 
-        // Check that the target language is valid:
-        if (!this.language_codes.includes(generate_kwargs.tgt_lang)) {
-            throw new Error(`Target language code "${generate_kwargs.tgt_lang}" is not valid. Must be one of: {${this.language_codes.join(', ')}}`);
-        }
-
-        // Allow `src_lang` to be optional. If not set, we'll use the tokenizer's default.
-        if (generate_kwargs.src_lang !== undefined) {
-            // Check that the source language is valid:
-            if (!this.language_codes.includes(generate_kwargs.src_lang)) {
-                throw new Error(`Source language code "${generate_kwargs.src_lang}" is not valid. Must be one of: {${this.language_codes.join(', ')}}`);
-            }
-
-            // In the same way as the Python library, we override the post-processor
-            // to force the source language to be first:
-            for (let item of this.post_processor.config.single) {
-                if ('SpecialToken' in item && this.languageRegex.test(item.SpecialToken.id)) {
-                    item.SpecialToken.id = generate_kwargs.src_lang;
-                    break;
-                }
-            }
-        }
-
-        // Override the `forced_bos_token_id` to force the correct language
-        generate_kwargs.forced_bos_token_id = this.model.convert_tokens_to_ids([generate_kwargs.tgt_lang])[0];
+        this.languageRegex = /^__[a-z]{2,3}__$/;
+        this.language_codes = this.special_tokens
+            .filter(x => this.languageRegex.test(x))
+            .map(x => x.slice(2, -2));
+        this.lang_to_token = x => `__${x}__`;
+    }
 
-        return this._call(raw_inputs, tokenizer_options);
+    /**
+     * Helper function to build translation inputs for an `M2M100Tokenizer`.
+     * @param {string|string[]} raw_inputs The text to tokenize.
+     * @param {Object} tokenizer_options Options to be sent to the tokenizer
+     * @param {Object} generate_kwargs Generation options.
+     * @returns {Object} Object to be passed to the model.
+     */
+    _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) {
+        return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs);
     }
 }
 
@@ -3485,6 +3547,7 @@ export class AutoTokenizer {
         'MarianTokenizer': MarianTokenizer,
         'BloomTokenizer': BloomTokenizer,
         'NllbTokenizer': NllbTokenizer,
+        'M2M100Tokenizer': M2M100Tokenizer,
         'LlamaTokenizer': LlamaTokenizer,
         'XLMRobertaTokenizer': XLMRobertaTokenizer,
         'MPNetTokenizer': MPNetTokenizer,
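
With `M2M100Tokenizer` registered in `AutoTokenizer`, the new code path can be exercised directly. A hedged usage sketch, not part of the diff: it assumes the `Xenova/m2m100_418M` conversion referenced in the JSDoc example above and the project's package import, and the tokenizer options shown are illustrative:

```javascript
import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/m2m100_418M');

// The helper validates both codes, rewrites the post-processor so the
// source token ('__zh__') comes first, and sets `forced_bos_token_id`
// to the id of the target token ('__en__').
const generate_kwargs = { src_lang: 'zh', tgt_lang: 'en' };
const model_inputs = tokenizer._build_translation_inputs(
    '生活就像一盒巧克力。',
    { padding: true, truncation: true }, // illustrative options
    generate_kwargs,
);
console.log(generate_kwargs.forced_bos_token_id); // id of '__en__'
```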

tests/generate_tests.py

Lines changed: 8 additions & 0 deletions
@@ -21,6 +21,11 @@
     ],
 }
 
+TOKENIZERS_TO_IGNORE = [
+    # TODO: remove when https://github.com/huggingface/transformers/pull/25478 is merged
+    'facebook/m2m100_418M',
+]
+
 TOKENIZER_TEST_DATA = {
     "shared": [
         "hello world",
@@ -92,6 +97,9 @@ def generate_tokenizer_tests():
     for model_type, tokenizer_names in tokenizers_to_test:
         print(f'Generating tests for {model_type}')
         for tokenizer_name in tokenizer_names:
+            if tokenizer_name in TOKENIZERS_TO_IGNORE:
+                continue
+
             print(' -', tokenizer_name)
 
             try:
