Commit 5ddc472

Add support for nougat models (image-to-text) (#391)
* Add `NougatTokenizer`
* Add nougat unit tests
* Add support for `NougatImageProcessor`
* Add `crop` function to `RawImage`
* Fix `RawImage` save function
  OffscreenCanvas does not have `toDataURL` function
* Add listed support for nougat models
* Fix `min`/`max` function typing
* Add unknown token to tokenizer class
* Implement `NoBadWordsLogitsProcessor`
* Use `NoBadWordsLogitsProcessor` in `generate`
* Fix regex group substitutions
  Python uses \1, \2, etc. for group substitutions, but JavaScript uses $1, $2, etc.
* Create `regexSplit` helper function to split but keep delimiter
* Fix splitting for String pattern types
* Fix docstring
1 parent 7cf8a2c commit 5ddc472

10 files changed: +285 -24 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -300,6 +300,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaicML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
+1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic.
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.

docs/snippets/6_supported-models.snippet

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@
 1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaicML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
+1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic.
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.

scripts/supported_models.py

Lines changed: 5 additions & 0 deletions
@@ -348,6 +348,11 @@
         'google/mt5-small',
         'google/mt5-base',
     ],
+    'nougat': [
+        # Image-to-text
+        'facebook/nougat-small',
+        'facebook/nougat-base',
+    ],
     'opt': [
         # Text generation
         'facebook/opt-125m',

src/models.js

Lines changed: 4 additions & 3 deletions
@@ -68,6 +68,7 @@ import {
     WhisperTimeStampLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
+    NoBadWordsLogitsProcessor,
     MinLengthLogitsProcessor,
     MinNewTokensLengthLogitsProcessor,
 
@@ -857,9 +858,9 @@ export class PreTrainedModel extends Callable {
         // }
         // }
 
-        // if (generation_config.bad_words_ids !== null) {
-        //     processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
-        // }
+        if (generation_config.bad_words_ids !== null) {
+            processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
+        }
 
         if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
             processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));

src/processors.js

Lines changed: 74 additions & 8 deletions
@@ -30,6 +30,7 @@ import {
 } from './utils/hub.js';
 
 import {
+    min,
     max,
     softmax,
     FFT,
@@ -207,6 +208,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
         this.do_center_crop = this.config.do_center_crop;
         this.crop_size = this.config.crop_size;
         this.do_convert_rgb = this.config.do_convert_rgb ?? true;
+        this.do_crop_margin = this.config.do_crop_margin;
 
         this.pad_size = this.config.pad_size;
         this.do_pad = this.config.do_pad;
@@ -249,6 +251,44 @@ export class ImageFeatureExtractor extends FeatureExtractor {
     }
 
 
+    /**
+     * Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the threshold).
+     * @param {RawImage} image The image to be cropped.
+     * @param {number} gray_threshold Value below which pixels are considered to be gray.
+     * @returns {Promise<RawImage>} The cropped image.
+     */
+    async crop_margin(image, gray_threshold = 200) {
+
+        const gray_image = image.clone().grayscale();
+
+        const minValue = min(gray_image.data)[0];
+        const maxValue = max(gray_image.data)[0];
+        const diff = maxValue - minValue;
+
+        if (diff === 0) {
+            return image;
+        }
+
+        const threshold = gray_threshold / 255;
+
+        let x_min = gray_image.width, y_min = gray_image.height, x_max = 0, y_max = 0;
+        for (let j = 0; j < gray_image.height; ++j) {
+            const row = j * gray_image.width;
+            for (let i = 0; i < gray_image.width; ++i) {
+                if ((gray_image.data[row + i] - minValue) / diff < threshold) {
+                    // We have a non-zero pixel, so we update the min/max values accordingly
+                    x_min = Math.min(x_min, i);
+                    y_min = Math.min(y_min, j);
+                    x_max = Math.max(x_max, i);
+                    y_max = Math.max(y_max, j);
+                }
+            }
+        }
+
+        image = await image.crop([x_min, y_min, x_max, y_max]);
+        return image;
+    }
+
     /**
      * Pad the image by a certain amount.
      * @param {Float32Array} pixelData The pixel data to pad.
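The bounding-box search in `crop_margin` can be tried on a plain array without `RawImage`. The sketch below is an illustration only: `marginBoundingBox` is a hypothetical helper, not part of the commit, and it assumes a single-channel image stored row by row.

```javascript
// Hypothetical helper (not part of the commit): returns the bounding box
// [x_min, y_min, x_max, y_max] of all pixels darker than the threshold,
// using the same min/max rescaling as `crop_margin`.
function marginBoundingBox(data, width, height, grayThreshold = 200) {
    let minValue = Infinity, maxValue = -Infinity;
    for (const v of data) {
        if (v < minValue) minValue = v;
        if (v > maxValue) maxValue = v;
    }
    const diff = maxValue - minValue;
    if (diff === 0) return null; // uniform image: nothing to crop

    const threshold = grayThreshold / 255;
    let xMin = width, yMin = height, xMax = 0, yMax = 0;
    for (let j = 0; j < height; ++j) {
        for (let i = 0; i < width; ++i) {
            // Rescale to [0, 1] and keep pixels below the threshold (content).
            if ((data[j * width + i] - minValue) / diff < threshold) {
                xMin = Math.min(xMin, i);
                yMin = Math.min(yMin, j);
                xMax = Math.max(xMax, i);
                yMax = Math.max(yMax, j);
            }
        }
    }
    return [xMin, yMin, xMax, yMax];
}

// A 4x3 white page with two dark pixels in the middle row.
const page = [
    255, 255, 255, 255,
    255,   0,  10, 255,
    255, 255, 255, 255,
];
const box = marginBoundingBox(page, 4, 3);
console.log(box); // [ 1, 1, 2, 1 ]
```

Because the values are rescaled by the image's own minimum and maximum, the darkest pixel always falls below the threshold, so a non-uniform image always yields a valid box.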
@@ -279,7 +319,12 @@ export class ImageFeatureExtractor extends FeatureExtractor {
         // Only add padding if there is a difference in size
         if (paddedImageWidth !== imageWidth || paddedImageHeight !== imageHeight) {
             const paddedPixelData = new Float32Array(paddedImageWidth * paddedImageHeight * imageChannels);
-            if (constant_values !== 0) {
+            if (Array.isArray(constant_values)) {
+                // Fill with constant values, cycling through the array
+                for (let i = 0; i < paddedPixelData.length; ++i) {
+                    paddedPixelData[i] = constant_values[i % imageChannels];
+                }
+            } else if (constant_values !== 0) {
                 paddedPixelData.fill(constant_values);
             }
 
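The per-channel fill added here can be sketched in isolation. The example assumes channel-last (interleaved) pixel data, which is what indexing with `i % imageChannels` implies; `fillPadding` is a hypothetical name used only for this illustration.

```javascript
// Hypothetical standalone version of the fill step (not the library function).
// With interleaved channels, index i belongs to channel i % channels.
function fillPadding(length, channels, constantValues) {
    const out = new Float32Array(length); // zero-initialised
    if (Array.isArray(constantValues)) {
        for (let i = 0; i < out.length; ++i) {
            out[i] = constantValues[i % channels];
        }
    } else if (constantValues !== 0) {
        out.fill(constantValues);
    }
    return out;
}

// Two pixels with three channels each.
const padded = fillPadding(6, 3, [-1, -2, -3]);
console.log(Array.from(padded)); // [ -1, -2, -3, -1, -2, -3 ]
```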
@@ -347,15 +392,21 @@ export class ImageFeatureExtractor extends FeatureExtractor {
      */
     async preprocess(image) {
 
-        // First, convert image to RGB if specified in config.
-        if (this.do_convert_rgb) {
-            image = image.rgb();
+        if (this.do_crop_margin) {
+            // NOTE: Specific to nougat processors. This is done before resizing,
+            // and can be interpreted as a pre-preprocessing step.
+            image = await this.crop_margin(image);
         }
 
         const srcWidth = image.width; // original width
         const srcHeight = image.height; // original height
 
-        // Next, resize all images
+        // Convert image to RGB if specified in config.
+        if (this.do_convert_rgb) {
+            image = image.rgb();
+        }
+
+        // Resize all images
         if (this.do_resize) {
             // TODO:
             // For efficiency reasons, it might be best to merge the resize and center crop operations into one.
@@ -541,17 +592,31 @@ export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
 export class BeitFeatureExtractor extends ImageFeatureExtractor { }
 export class DonutFeatureExtractor extends ImageFeatureExtractor {
     pad_image(pixelData, imgDims, padSize, options = {}) {
+        const [imageWidth, imageHeight, imageChannels] = imgDims;
+
+        let image_mean = this.image_mean;
+        if (!Array.isArray(this.image_mean)) {
+            image_mean = new Array(imageChannels).fill(image_mean);
+        }
+
+        let image_std = this.image_std;
+        if (!Array.isArray(this.image_std)) {
+            image_std = new Array(imageChannels).fill(image_std);
+        }
+
+        const constant_values = image_mean.map((x, i) => - x / image_std[i]);
+
         return super.pad_image(pixelData, imgDims, padSize, {
             center: true,
 
-            // Since normalization is done after padding, we need to pad with -1.
-            // NOTE: This only works if `image_mean = 0.5` and `image_std = 0.5`.
+            // Since normalization is done after padding, we need to use certain constant values to ensure the same behaviour is observed.
             // For more information, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/image_processing_donut.py#L433-L451
-            constant_values: -1,
+            constant_values: constant_values,
             ...options,
         });
     }
 }
+export class NougatImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor
 
 /**
  * @typedef {object} DetrFeatureExtractorResultProps
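One way to read the new `constant_values`: assuming the intent is to pad with the value a zero pixel takes after normalization `(x - mean) / std`, that value is `-mean / std` per channel. With `image_mean = 0.5` and `image_std = 0.5` it reduces to the previously hard-coded `-1`. The helper below is a hypothetical sketch of that arithmetic, not the library code.

```javascript
// Hypothetical helper: the normalized value of a zero pixel, per channel.
// Scalars are broadcast to one value per channel, as in the diff above.
function normalizedZero(mean, std, channels) {
    const means = Array.isArray(mean) ? mean : new Array(channels).fill(mean);
    const stds = Array.isArray(std) ? std : new Array(channels).fill(std);
    return means.map((m, i) => (0 - m) / stds[i]);
}

console.log(normalizedZero(0.5, 0.5, 3)); // [ -1, -1, -1 ]
console.log(normalizedZero([0.5, 0.25], [0.5, 0.5], 2)); // [ -1, -0.5 ]
```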
@@ -1573,6 +1638,7 @@ export class AutoProcessor {
         DetrFeatureExtractor,
         YolosFeatureExtractor,
         DonutFeatureExtractor,
+        NougatImageProcessor,
 
         SamImageProcessor,
         Swin2SRImageProcessor,

src/tokenizers.js

Lines changed: 65 additions & 6 deletions
@@ -56,20 +56,56 @@ async function loadTokenizer(pretrained_model_name_or_path, options) {
     return info;
 }
 
+
+/**
+ * Helper function to split a string on a regex, but keep the delimiters.
+ * This is required, because the JavaScript `.split()` method does not keep the delimiters,
+ * and wrapping in a capturing group causes issues with existing capturing groups (due to nesting).
+ * @param {string} text The text to split.
+ * @param {RegExp} regex The regex to split on.
+ * @returns {string[]} The split string.
+ */
+function regexSplit(text, regex) {
+    const result = [];
+    let prev = 0;
+    for (const match of text.matchAll(regex)) {
+        const fullMatch = match[0];
+        if (prev < match.index) {
+            result.push(text.slice(prev, match.index));
+        }
+        if (fullMatch.length > 0) {
+            result.push(fullMatch);
+        }
+        prev = match.index + fullMatch.length;
+    }
+    if (prev < text.length) {
+        result.push(text.slice(prev));
+    }
+    return result;
+}
+
+
 /**
  * Helper method to construct a pattern from a config object.
  * @param {Object} pattern The pattern object.
- * @param {boolean} invert Whether to invert the pattern (only applicable for Regex patterns).
- * @returns {RegExp|string|null} The compiled pattern.
+ * @param {boolean} invert Whether to invert the pattern.
+ * @returns {RegExp|null} The compiled pattern.
  */
 function createPattern(pattern, invert = true) {
 
     if (pattern.Regex !== undefined) {
-        // NOTE: if invert is true, we wrap the pattern in a group so that it is kept when performing .split()
-        return new RegExp(invert ? pattern.Regex : `(${pattern.Regex})`, 'gu');
+        // In certain cases, the pattern may contain unnecessary escape sequences (e.g., \# or \& or \~).
+        // i.e., valid in Python (where the patterns are exported from) but invalid in JavaScript (where the patterns are parsed).
+        // This isn't an issue when creating the regex w/o the 'u' flag, but it is when the 'u' flag is used.
+        // For this reason, it is necessary to remove these backslashes before creating the regex.
+        // See https://stackoverflow.com/a/63007777/13989043 for more information
+        const regex = pattern.Regex.replace(/\\([#&~])/g, '$1'); // TODO: add more characters to this list if necessary
+        return new RegExp(regex, 'gu');
 
     } else if (pattern.String !== undefined) {
-        return pattern.String;
+        const escaped = escapeRegExp(pattern.String);
+        // NOTE: if invert is false, we wrap the pattern in a group so that it is kept when performing .split()
+        return new RegExp(invert ? escaped : `(${escaped})`, 'gu');
 
     } else {
         console.warn('Unknown pattern type:', pattern)
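The escape-sequence problem described in the `createPattern` comments can be reproduced directly: with the `u` flag, JavaScript rejects identity escapes such as `\#`, while the same source compiles without it. The helper names below (`stripUnneededEscapes`, `compiles`) are hypothetical and exist only for this sketch; the replacement regex is the one from the diff.

```javascript
// Same replacement as in the diff: drop the backslash before #, & and ~.
function stripUnneededEscapes(source) {
    return source.replace(/\\([#&~])/g, '$1');
}

// Returns true if the pattern source compiles with the given flags.
function compiles(source, flags) {
    try {
        new RegExp(source, flags);
        return true;
    } catch (e) {
        return false; // SyntaxError: invalid escape in unicode mode
    }
}

const source = '\\#+'; // the pattern text is: \#+
console.log(compiles(source, 'g'));                        // true
console.log(compiles(source, 'gu'));                       // false
console.log(compiles(stripUnneededEscapes(source), 'gu')); // true
```

Note that this simple replacement does not distinguish an escaped backslash followed by `#` from an escaped `#`, so it is a sketch of the idea rather than a general solution.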
@@ -813,6 +849,8 @@ class Normalizer extends Callable {
                 return new Replace(config);
             case 'NFC':
                 return new NFC(config);
+            case 'NFKC':
+                return new NFKC(config);
             case 'NFKD':
                 return new NFKD(config);
             case 'Strip':
@@ -888,6 +926,21 @@ class NFC extends Normalizer {
     }
 }
 
+/**
+ * NFKC Normalizer.
+ * @extends Normalizer
+ */
+class NFKC extends Normalizer {
+    /**
+     * Normalize text using NFKC normalization.
+     * @param {string} text The text to be normalized.
+     * @returns {string} The normalized text.
+     */
+    normalize(text) {
+        text = text.normalize('NFKC')
+        return text;
+    }
+}
 /**
  * NFKD Normalizer.
  * @extends Normalizer
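For reference, a short illustration of what NFKC does compared with NFC, using only the built-in `String.prototype.normalize`. The sample strings are chosen for this sketch and do not come from the commit.

```javascript
// NFKC applies compatibility decomposition followed by canonical composition,
// so ligatures, circled digits and fullwidth letters map to plain characters.
const ligature = '\uFB01';  // "ﬁ" (LATIN SMALL LIGATURE FI)
const circled = '\u2460';   // "①" (CIRCLED DIGIT ONE)
const fullwidth = '\uFF21'; // "Ａ" (FULLWIDTH LATIN CAPITAL LETTER A)

console.log(ligature.normalize('NFKC'));  // "fi"
console.log(circled.normalize('NFKC'));   // "1"
console.log(fullwidth.normalize('NFKC')); // "A"

// NFC leaves compatibility characters untouched.
console.log(ligature.normalize('NFC') === ligature); // true
```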
@@ -1299,7 +1352,7 @@ class SplitPreTokenizer extends PreTokenizer {
         if (this.config.invert) {
             return text.match(this.pattern) || [];
         } else {
-            return text.split(this.pattern).filter(x => x);
+            return regexSplit(text, this.pattern);
         }
     }
 }
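The difference between `.split()` and the new helper can be seen on a small input. The function below is a standalone copy of `regexSplit` from this commit so the example runs on its own; the sample strings are assumptions made for illustration. Note that `matchAll` requires a regex with the `g` flag.

```javascript
// Standalone copy of `regexSplit` from the diff, for demonstration.
function regexSplit(text, regex) {
    const result = [];
    let prev = 0;
    for (const match of text.matchAll(regex)) {
        const fullMatch = match[0];
        if (prev < match.index) {
            result.push(text.slice(prev, match.index)); // text before the delimiter
        }
        if (fullMatch.length > 0) {
            result.push(fullMatch); // the delimiter itself
        }
        prev = match.index + fullMatch.length;
    }
    if (prev < text.length) {
        result.push(text.slice(prev)); // trailing text
    }
    return result;
}

// Plain split drops the delimiters.
console.log('a1b22c'.split(/\d+/));        // [ 'a', 'b', 'c' ]
// regexSplit keeps them as separate items.
console.log(regexSplit('a1b22c', /\d+/g)); // [ 'a', '1', 'b', '22', 'c' ]
// Wrapping in a capturing group duplicates output when the pattern already has groups.
console.log('a1b'.split(/((\d))/));        // [ 'a', '1', '1', 'b' ]
```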
@@ -2190,6 +2243,9 @@ export class PreTrainedTokenizer extends Callable {
         this.sep_token = this.getToken(tokenizerConfig, 'sep_token');
         this.sep_token_id = this.model.tokens_to_ids.get(this.sep_token);
 
+        this.unk_token = this.getToken(tokenizerConfig, 'unk_token');
+        this.unk_token_id = this.model.tokens_to_ids.get(this.unk_token);
+
         this.model_max_length = tokenizerConfig.model_max_length;
 
         /** @type {boolean} Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). */
@@ -3756,6 +3812,8 @@ export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { }
 
 export class SpeechT5Tokenizer extends PreTrainedTokenizer { }
 
+export class NougatTokenizer extends PreTrainedTokenizer { }
+
 /**
  * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function.
  * The chosen tokenizer class is determined by the type specified in the tokenizer config.
@@ -3798,6 +3856,7 @@ export class AutoTokenizer {
         BlenderbotTokenizer,
         BlenderbotSmallTokenizer,
         SpeechT5Tokenizer,
+        NougatTokenizer,
 
         // Base case:
         PreTrainedTokenizer,

src/utils/generation.js

Lines changed: 43 additions & 0 deletions
@@ -491,6 +491,49 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
     }
 }
 
+export class NoBadWordsLogitsProcessor extends LogitsProcessor {
+    /**
+     * Create a `NoBadWordsLogitsProcessor`.
+     * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
+     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+     */
+    constructor(bad_words_ids, eos_token_id) {
+        super();
+        this.bad_words_ids = bad_words_ids;
+        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+    }
+
+    /**
+     * Apply logit processor.
+     * @param {Array} input_ids The input IDs.
+     * @param {Object} logits The logits.
+     * @returns {Object} The processed logits.
+     */
+    _call(input_ids, logits) {
+
+        for (const bad_word_ids of this.bad_words_ids) {
+            // Whether to modify the logits of the last token in the bad word id sequence
+            let mark = true;
+
+            // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last),
+            // then we set the logits of the last bad word id to -Infinity.
+            for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids.length; ++i) {
+
+                if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) {
+                    // We have found a mismatch
+                    mark = false;
+                    break;
+                }
+            }
+            if (mark) {
+                logits.data[bad_word_ids.at(-1)] = -Infinity;
+            }
+        }
+
+        return logits
+    }
+}
+
 /**
  * Class that holds a configuration for a generation task.
  */
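The idea behind the processor can be sketched on plain arrays: a bad word's final token is banned only when the tokens generated so far end with the rest of that bad word. `banBadWords` below is a hypothetical simplification, not the class from this commit. In particular, it skips a bad word whose prefix is longer than the input, whereas the loop in the diff leaves `mark` set in that case; that difference is an assumption of this sketch.

```javascript
// Hypothetical simplified version operating on a plain array of logits.
function banBadWords(inputIds, logits, badWordsIds) {
    for (const badWord of badWordsIds) {
        const prefix = badWord.slice(0, -1); // every token except the last
        if (prefix.length > inputIds.length) continue; // prefix cannot match yet

        // Note: slice(-0) would return the whole array, so handle the empty prefix.
        const tail = prefix.length === 0 ? [] : inputIds.slice(-prefix.length);
        const matches = prefix.every((id, k) => id === tail[k]);
        if (matches) {
            logits[badWord.at(-1)] = -Infinity; // this token can no longer be chosen
        }
    }
    return logits;
}

const badWords = [[3], [1, 2]];

// Input ends with 1, so both token 3 (single-token word) and token 2 are banned.
const a = banBadWords([0, 1], [0, 0, 0, 0, 0], badWords);
console.log(a); // [ 0, 0, -Infinity, -Infinity, 0 ]

// Input ends with 0, so only the single-token bad word applies.
const b = banBadWords([1, 0], [0, 0, 0, 0, 0], badWords);
console.log(b); // [ 0, 0, 0, -Infinity, 0 ]
```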