Skip to content

Commit 6f27a10

Browse files
authored
Add support for PaliGemma (& PaliGemma2) (#1074)
* Bump versions * Add support for PaliGemma (&PaliGemma2) * Add unit tests * Remove debug line * Revert version bump (move to new PR)
1 parent e4dac8a commit 6f27a10

File tree

8 files changed

+213
-5
lines changed

8 files changed

+213
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
376376
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
377377
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
378378
1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
379+
1. **[PaliGemma](https://huggingface.co/docs/transformers/main/model_doc/paligemma)** (from Google) released with the papers [PaliGemma: A versatile 3B VLM for transfer](https://arxiv.org/abs/2407.07726) and [PaliGemma 2: A Family of Versatile VLMs for Transfer](https://arxiv.org/abs/2412.03555) by the PaliGemma Google team.
379380
1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/abs/2306.09364) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
380381
1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from Princeton University, IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/abs/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
381382
1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://arxiv.org/abs/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://arxiv.org/abs/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee.

docs/snippets/6_supported-models.snippet

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
9292
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
9393
1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
94+
1. **[PaliGemma](https://huggingface.co/docs/transformers/main/model_doc/paligemma)** (from Google) released with the papers [PaliGemma: A versatile 3B VLM for transfer](https://arxiv.org/abs/2407.07726) and [PaliGemma 2: A Family of Versatile VLMs for Transfer](https://arxiv.org/abs/2412.03555) by the PaliGemma Google team.
9495
1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/abs/2306.09364) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
9596
1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from Princeton University, IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/abs/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
9697
1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://arxiv.org/abs/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://arxiv.org/abs/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee.

src/models.js

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,9 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
558558
new_model_inputs.use_cache_branch = boolTensor(!!past_key_values);
559559
}
560560
if (session.inputNames.includes('position_ids') && new_model_inputs.attention_mask && !new_model_inputs.position_ids) {
561-
new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values);
561+
// NOTE: Handle a special case for paligemma models, where positions are 1-indexed
562+
const start_index = self.config.model_type === 'paligemma' ? 1 : 0;
563+
new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values, start_index);
562564
}
563565

564566
// Unpack the `past_key_values` object into model inputs
@@ -694,14 +696,14 @@ async function imageTextToTextForward(self, {
694696
* @param {Tensor} attention_mask
695697
* @returns {{data: BigInt64Array, dims: number[]}}
696698
*/
697-
function cumsum_masked_fill(attention_mask) {
699+
function cumsum_masked_fill(attention_mask, start_index = 0) {
698700
const [bz, seq_len] = attention_mask.dims;
699701
const attn_mask_data = attention_mask.data;
700702

701703
const data = new BigInt64Array(attn_mask_data.length);
702704
for (let i = 0; i < bz; ++i) {
703705
const start = i * seq_len;
704-
let sum = BigInt(0);
706+
let sum = BigInt(start_index);
705707
for (let j = 0; j < seq_len; ++j) {
706708
const index = start + j;
707709
if (attn_mask_data[index] === 0n) {
@@ -728,10 +730,10 @@ function cumsum_masked_fill(attention_mask) {
728730
* position_ids = position_ids[:, -input_ids.shape[1] :]
729731
* ```
730732
*/
731-
function createPositionIds(model_inputs, past_key_values = null) {
733+
function createPositionIds(model_inputs, past_key_values = null, start_index = 0) {
732734
const { input_ids, inputs_embeds, attention_mask } = model_inputs;
733735

734-
const { data, dims } = cumsum_masked_fill(attention_mask);
736+
const { data, dims } = cumsum_masked_fill(attention_mask, start_index);
735737
let position_ids = new Tensor('int64', data, dims);
736738
if (past_key_values) {
737739
const offset = -(input_ids ?? inputs_embeds).dims.at(1);
@@ -3548,6 +3550,30 @@ export class Florence2ForConditionalGeneration extends Florence2PreTrainedModel
35483550
}
35493551
}
35503552

3553+
export class PaliGemmaPreTrainedModel extends PreTrainedModel {
3554+
forward_params = [
3555+
'input_ids',
3556+
// 'inputs_embeds',
3557+
'attention_mask',
3558+
'pixel_values',
3559+
'position_ids',
3560+
'past_key_values',
3561+
];
3562+
}
3563+
3564+
export class PaliGemmaForConditionalGeneration extends PaliGemmaPreTrainedModel {
3565+
_merge_input_ids_with_image_features(kwargs) {
3566+
const vision_hidden_size = kwargs.image_features.dims.at(-1);
3567+
const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);
3568+
3569+
return default_merge_input_ids_with_image_features({
3570+
// @ts-ignore
3571+
image_token_id: this.config.image_token_index,
3572+
...kwargs,
3573+
image_features: reshaped_image_hidden_states,
3574+
})
3575+
}
3576+
}
35513577

35523578
//////////////////////////////////////////////////
35533579
// Idefics3 Models
@@ -7015,6 +7041,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
70157041
['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]],
70167042
['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]],
70177043
['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
7044+
['paligemma', ['PaliGemmaForConditionalGeneration', PaliGemmaForConditionalGeneration]],
70187045
]);
70197046

70207047
const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import { Processor } from "../../base/processing_utils.js";
2+
import { AutoImageProcessor } from "../auto/image_processing_auto.js";
3+
import { AutoTokenizer } from "../../tokenizers.js";
4+
5+
const IMAGE_TOKEN = "<image>";
6+
7+
function build_string_from_input(
8+
prompt,
9+
bos_token,
10+
image_seq_len,
11+
image_token,
12+
num_images,
13+
) {
14+
return `${image_token.repeat(image_seq_len * num_images)}${bos_token}${prompt}\n`
15+
}
16+
17+
export class PaliGemmaProcessor extends Processor {
18+
static tokenizer_class = AutoTokenizer
19+
static image_processor_class = AutoImageProcessor
20+
static uses_processor_config = false;
21+
22+
/**
23+
* @typedef {import('../../utils/image.js').RawImage} RawImage
24+
*/
25+
26+
// `images` is required, `text` is optional
27+
async _call(/** @type {RawImage|RawImage[]} */ images, text = null, kwargs = {}) {
28+
if (!text) {
29+
console.warn(
30+
"You are using PaliGemma without a text prefix. It will perform as a picture-captioning model."
31+
)
32+
text = ""
33+
}
34+
35+
if (!Array.isArray(images)) {
36+
images = [images]
37+
}
38+
39+
if (!Array.isArray(text)) {
40+
text = [text]
41+
}
42+
43+
const bos_token = this.tokenizer.bos_token;
44+
const image_seq_length = this.image_processor.config.image_seq_length;
45+
let input_strings;
46+
if (text.some((t) => t.includes(IMAGE_TOKEN))) {
47+
input_strings = text.map(
48+
sample => {
49+
const expanded_sample = sample.replaceAll(IMAGE_TOKEN, IMAGE_TOKEN.repeat(image_seq_length));
50+
const bos_rfind_index = expanded_sample.lastIndexOf(IMAGE_TOKEN);
51+
const bos_index = bos_rfind_index === -1 ? 0 : bos_rfind_index + IMAGE_TOKEN.length;
52+
return expanded_sample.slice(0, bos_index) + bos_token + expanded_sample.slice(bos_index) + "\n";
53+
}
54+
)
55+
} else {
56+
console.warn(
57+
"You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special " +
58+
"image tokens in the text, as many tokens as there are images per each text. It is recommended to " +
59+
"add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images " +
60+
"each text has and add special tokens."
61+
)
62+
63+
input_strings = text.map(
64+
sample => build_string_from_input(
65+
sample,
66+
bos_token,
67+
image_seq_length,
68+
IMAGE_TOKEN,
69+
images.length,
70+
)
71+
)
72+
}
73+
74+
const text_inputs = this.tokenizer(input_strings, kwargs);
75+
const image_inputs = await this.image_processor(images, kwargs);
76+
77+
return {
78+
...image_inputs,
79+
...text_inputs,
80+
}
81+
}
82+
}

src/models/processors.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ export * from './idefics3/processing_idefics3.js';
44
export * from './janus/processing_janus.js';
55
export * from './jina_clip/processing_jina_clip.js';
66
export * from './owlvit/processing_owlvit.js';
7+
export * from './paligemma/processing_paligemma.js';
78
export * from './pyannote/processing_pyannote.js';
89
export * from './qwen2_vl/processing_qwen2_vl.js';
910
export * from './sam/processing_sam.js';

src/tokenizers.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2605,6 +2605,12 @@ export class PreTrainedTokenizer extends Callable {
26052605
this.unk_token = this.getToken('unk_token');
26062606
this.unk_token_id = this.model.tokens_to_ids.get(this.unk_token);
26072607

2608+
this.bos_token = this.getToken('bos_token');
2609+
this.bos_token_id = this.model.tokens_to_ids.get(this.bos_token);
2610+
2611+
this.eos_token = this.getToken('eos_token');
2612+
this.eos_token_id = this.model.tokens_to_ids.get(this.eos_token);
2613+
26082614
this.model_max_length = tokenizerConfig.model_max_length;
26092615

26102616
/** @type {boolean} Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). */

tests/processors.test.js

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ const MODELS = {
4848
florence2: "Xenova/tiny-random-Florence2ForConditionalGeneration",
4949
qwen2_vl: "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration",
5050
idefics3: "hf-internal-testing/tiny-random-Idefics3ForConditionalGeneration",
51+
paligemma: "hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration",
5152
};
5253

5354
const BASE_URL = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/";
@@ -1196,5 +1197,40 @@ describe("Processors", () => {
11961197
},
11971198
MAX_TEST_TIME,
11981199
);
1200+
1201+
describe(
1202+
"PaliGemmaProcessor",
1203+
() => {
1204+
/** @type {import('../src/transformers.js').PaliGemmaProcessor} */
1205+
let processor;
1206+
let images = {};
1207+
1208+
beforeAll(async () => {
1209+
processor = await AutoProcessor.from_pretrained(MODELS.paligemma);
1210+
images = {
1211+
white_image: await load_image(TEST_IMAGES.white_image),
1212+
};
1213+
});
1214+
1215+
it("Image-only (default text)", async () => {
1216+
const { input_ids, pixel_values } = await processor(images.white_image);
1217+
compare(input_ids.dims, [1, 258]);
1218+
compare(pixel_values.dims, [1, 3, 224, 224]);
1219+
});
1220+
1221+
it("Single image & text", async () => {
1222+
const { input_ids, pixel_values } = await processor(images.white_image, "<image>What is on the flower?");
1223+
compare(input_ids.dims, [1, 264]);
1224+
compare(pixel_values.dims, [1, 3, 224, 224]);
1225+
});
1226+
1227+
it("Multiple images & text", async () => {
1228+
const { input_ids, pixel_values } = await processor([images.white_image, images.white_image], "<image><image>Describe the images.");
1229+
compare(input_ids.dims, [1, 518]);
1230+
compare(pixel_values.dims, [2, 3, 224, 224]);
1231+
});
1232+
},
1233+
MAX_TEST_TIME,
1234+
);
11991235
});
12001236
});

tests/tiny_random.test.js

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
Processor,
2121
Florence2Processor,
2222
Idefics3Processor,
23+
PaliGemmaProcessor,
2324

2425
// Models
2526
LlamaForCausalLM,
@@ -54,6 +55,7 @@ import {
5455
VisionEncoderDecoderModel,
5556
Florence2ForConditionalGeneration,
5657
Qwen2VLForConditionalGeneration,
58+
PaliGemmaForConditionalGeneration,
5759
MarianMTModel,
5860
PatchTSTModel,
5961
PatchTSTForPrediction,
@@ -1072,6 +1074,58 @@ describe("Tiny random models", () => {
10721074
});
10731075
});
10741076

1077+
describe("paligemma", () => {
1078+
const text = "<image>What is on the flower?";
1079+
1080+
// Empty white image
1081+
const dims = [224, 224, 3];
1082+
const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims);
1083+
1084+
describe("PaliGemmaForConditionalGeneration", () => {
1085+
const model_id = "hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration";
1086+
1087+
/** @type {PaliGemmaForConditionalGeneration} */
1088+
let model;
1089+
/** @type {PaliGemmaProcessor} */
1090+
let processor;
1091+
beforeAll(async () => {
1092+
model = await PaliGemmaForConditionalGeneration.from_pretrained(model_id, {
1093+
// TODO move to config
1094+
...DEFAULT_MODEL_OPTIONS,
1095+
});
1096+
processor = await AutoProcessor.from_pretrained(model_id);
1097+
}, MAX_MODEL_LOAD_TIME);
1098+
1099+
it(
1100+
"forward",
1101+
async () => {
1102+
const inputs = await processor(image, text);
1103+
1104+
const { logits } = await model(inputs);
1105+
expect(logits.dims).toEqual([1, 264, 257216]);
1106+
expect(logits.mean().item()).toBeCloseTo(-0.0023024685215204954, 6);
1107+
},
1108+
MAX_TEST_EXECUTION_TIME,
1109+
);
1110+
1111+
it(
1112+
"batch_size=1",
1113+
async () => {
1114+
const inputs = await processor(image, text);
1115+
const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 });
1116+
1117+
const new_tokens = generate_ids.slice(null, [inputs.input_ids.dims.at(-1), null]);
1118+
expect(new_tokens.tolist()).toEqual([[91711n, 24904n, 144054n, 124983n, 83862n, 124983n, 124983n, 124983n, 141236n, 124983n]]);
1119+
},
1120+
MAX_TEST_EXECUTION_TIME,
1121+
);
1122+
1123+
afterAll(async () => {
1124+
await model?.dispose();
1125+
}, MAX_MODEL_DISPOSE_TIME);
1126+
});
1127+
});
1128+
10751129
describe("vision-encoder-decoder", () => {
10761130
describe("VisionEncoderDecoderModel", () => {
10771131
const model_id = "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2";

0 commit comments

Comments
 (0)