Skip to content

Commit 3a02a52

Browse files
committed
[WIP] Add support for idefics3 (SmolVLM)
1 parent 2c92943 commit 3a02a52

File tree

9 files changed

+501
-45
lines changed

9 files changed

+501
-45
lines changed

src/configs.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ function getNormalizedConfig(config) {
6969
case 'paligemma':
7070
case 'florence2':
7171
case 'llava_onevision':
72+
case 'idefics3':
7273
init_normalized_config = getNormalizedConfig(config.text_config);
7374
break;
7475
case 'moondream1':

src/models.js

Lines changed: 87 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,41 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
546546
}
547547

548548

549+
550+
/**
 * Scatters image features into the embedding positions occupied by image
 * placeholder tokens (the JS equivalent of a `masked_scatter`). Shared by
 * multiple vision-language models (e.g., Qwen2-VL, Idefics3).
 *
 * @param {Object} params
 * @param {number} params.image_token_id Token id that marks an image placeholder in `input_ids`.
 * @param {Object} params.inputs_embeds Text embeddings, indexable as `[batch][position].data` (TypedArray per position).
 * @param {Object} params.image_features Image features with `.dims` where `dims[0]` is the number of feature rows,
 * indexable as `[row].data` (TypedArray per row).
 * @param {Object} params.input_ids Input token ids; `.tolist()` returns a nested array of ids.
 * @param {Object} params.attention_mask Attention mask, passed through unchanged.
 * @returns {{inputs_embeds: Object, attention_mask: Object}} The (mutated in place) embeddings and the original mask.
 * @throws {Error} If the number of placeholder tokens does not equal the number of image feature rows.
 */
function default_merge_input_ids_with_image_features({
    image_token_id,
    inputs_embeds,
    image_features,
    input_ids,
}) {
    // Collect, per batch item, the positions of all image placeholder tokens.
    // NOTE: loose equality (==) is intentional — `tolist()` may yield BigInts.
    const image_tokens = input_ids.tolist().map(ids =>
        ids.reduce((acc, x, idx) => {
            if (x == image_token_id) acc.push(idx);
            return acc;
        }, [])
    );
    const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0);
    const n_image_features = image_features.dims[0];
    if (n_image_tokens !== n_image_features) {
        throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`);
    }

    // Equivalent to performing a masked_scatter: copy one feature row into
    // each placeholder position, in order.
    let img = 0;
    for (let i = 0; i < image_tokens.length; ++i) {
        const tokens = image_tokens[i];
        const embeds = inputs_embeds[i];
        for (let j = 0; j < tokens.length; ++j) {
            embeds[tokens[j]].data.set(image_features[img++].data)
        }
    }
    return { inputs_embeds, attention_mask }
}
582+
583+
549584
/**
550585
* Forward pass of an image-text-to-text model.
551586
* @param {Object} self The image-text-to-text model model.
@@ -582,11 +617,15 @@ async function imageTextToTextForward(self, {
582617

583618
if (!inputs_embeds) {
584619
// 1. Extract the input embeddings
620+
console.log('before encode_text');
585621
inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
622+
console.log('after encode_text', inputs_embeds.dims);
586623

587624
// 2. Possibly, merge text and images
588625
if (pixel_values && input_ids.dims[1] !== 1) {
626+
console.log('before encode_image');
589627
const image_features = await self.encode_image({ pixel_values, ...kwargs });
628+
console.log('after encode_image');
590629

591630
({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
592631
image_features,
@@ -3304,8 +3343,8 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
33043343
export class LlavaPreTrainedModel extends PreTrainedModel {
33053344
forward_params = [
33063345
'input_ids',
3307-
'pixel_values',
33083346
'attention_mask',
3347+
'pixel_values',
33093348
'position_ids',
33103349
'past_key_values',
33113350
];
@@ -3487,6 +3526,46 @@ export class Florence2ForConditionalGeneration extends Florence2PreTrainedModel
34873526
return decoder_outputs;
34883527
}
34893528
}
3529+
3530+
3531+
//////////////////////////////////////////////////
3532+
// Idefics3 Models
3533+
// Base class for Idefics3 models; declares the inputs forwarded to the
// underlying ONNX session(s).
export class Idefics3PreTrainedModel extends PreTrainedModel {
    // NOTE: unlike most other vision-language models here, Idefics3 also
    // forwards a `pixel_attention_mask` alongside `pixel_values`.
    forward_params = [
        'input_ids',
        'attention_mask',
        'pixel_values',
        'pixel_attention_mask',
        'position_ids',
        'past_key_values',
    ];
}
3543+
3544+
/**
 * The Idefics3 model (e.g., SmolVLM), which consists of a vision backbone and a language model.
 * (Fixed: previous doc comment was copy-pasted from the LLaVA class.)
 */
export class Idefics3ForConditionalGeneration extends Idefics3PreTrainedModel {

    /**
     * Encodes images into features using the vision encoder session.
     * @param {Object} inputs
     * @param {Object} inputs.pixel_values Pixel values of the images.
     * @param {Object} inputs.pixel_attention_mask Mask over valid (non-padded) pixels.
     * @returns {Promise<Object>} The `image_features` output of the vision encoder.
     */
    async encode_image({ pixel_values, pixel_attention_mask }) {
        const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values, pixel_attention_mask })).image_features;
        return features;
    }

    /**
     * Merges image features into the text embeddings at image-token positions.
     * @param {Object} kwargs Arguments forwarded to the shared merge helper.
     * @returns {{inputs_embeds: Object, attention_mask: Object}}
     */
    _merge_input_ids_with_image_features(kwargs) {
        // Flatten the image features to (num_rows, hidden) so that each image
        // placeholder token receives exactly one feature row.
        const vision_hidden_size = kwargs.image_features.dims.at(-1);
        const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);

        return default_merge_input_ids_with_image_features({
            // @ts-ignore
            image_token_id: this.config.image_token_id,
            ...kwargs,
            image_features: reshaped_image_hidden_states,
        })
    }
}
3566+
//////////////////////////////////////////////////
3567+
3568+
//////////////////////////////////////////////////
34903569
export class CLIPPreTrainedModel extends PreTrainedModel { }
34913570

34923571
/**
@@ -4280,36 +4359,12 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {
42804359
return features;
42814360
}
42824361

4283-
_merge_input_ids_with_image_features({
4284-
inputs_embeds,
4285-
image_features,
4286-
input_ids,
4287-
attention_mask,
4288-
}) {
4362+
_merge_input_ids_with_image_features(kwargs) {
4363+
return default_merge_input_ids_with_image_features({
42894364
// @ts-ignore
4290-
const { image_token_id } = this.config;
4291-
const image_tokens = input_ids.tolist().map(ids =>
4292-
ids.reduce((acc, x, idx) => {
4293-
if (x == image_token_id) acc.push(idx);
4294-
return acc;
4295-
}, [])
4296-
);
4297-
const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0);
4298-
const n_image_features = image_features.dims[0];
4299-
if (n_image_tokens !== n_image_features) {
4300-
throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`);
4301-
}
4302-
4303-
// Equivalent to performing a masked_scatter
4304-
let img = 0;
4305-
for (let i = 0; i < image_tokens.length; ++i) {
4306-
const tokens = image_tokens[i];
4307-
const embeds = inputs_embeds[i];
4308-
for (let j = 0; j < tokens.length; ++j) {
4309-
embeds[tokens[j]].data.set(image_features[img++].data)
4310-
}
4311-
}
4312-
return { inputs_embeds, attention_mask }
4365+
image_token_id: this.config.image_token_id,
4366+
...kwargs
4367+
})
43134368
}
43144369

43154370
prepare_inputs_for_generation(input_ids, model_inputs, generation_config) {
@@ -6914,6 +6969,7 @@ const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
69146969

69156970
const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([
69166971
['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]],
6972+
['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
69176973
]);
69186974

69196975
const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
@@ -6922,6 +6978,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
69226978
['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]],
69236979
['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]],
69246980
['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]],
6981+
['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
69256982
]);
69266983

69276984
const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
2+
3+
import {
4+
ImageProcessor,
5+
} from "../../base/image_processors_utils.js";
6+
import { cat, full, interpolate_4d } from "../../utils/tensor.js";
7+
8+
/**
 * Image processor for Idefics3 (e.g., SmolVLM). Optionally splits large images
 * into patches of at most `max_image_size.longest_edge` per side, appending a
 * resized "global" copy of the image after the patches.
 */
export class Idefics3ImageProcessor extends ImageProcessor {
    /**
     * @param {Object} config The processor configuration.
     * @param {boolean} [config.do_image_splitting=true] Whether to split images into patches.
     * @param {Object} config.max_image_size Maximum image size, of the form `{ longest_edge: number }`.
     */
    constructor(config) {
        super(config);

        this.do_image_splitting = config.do_image_splitting ?? true;
        this.max_image_size = config.max_image_size;
    }

    /**
     * Calculate size to resize images to, to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
     * @param {import('../../utils/tensor.js').Tensor} pixel_values Tensor of the image to resize.
     * @param {number} vision_encoder_max_size Maximum size of the output image. If the image is larger than this size,
     * it will be split into patches of this size, and the original image will be concatenated with the patches, resized to max_size.
     * @returns {{height: number, width: number}} Target size; each side is rounded up to a multiple of `vision_encoder_max_size`.
     */
    get_resize_for_vision_encoder(pixel_values, vision_encoder_max_size) {
        let [height, width] = pixel_values.dims.slice(-2);

        const aspect_ratio = width / height;
        if (width >= height) {
            // Round the longer side up to a multiple, derive the other side from
            // the aspect ratio, then round it up to a multiple as well.
            width = Math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size;
            height = Math.floor(width / aspect_ratio);
            height = Math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size;
        } else {
            height = Math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size;
            width = Math.floor(height * aspect_ratio);
            width = Math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size;
        }
        return { height, width };
    }

    /**
     * Preprocess one or more images into model inputs.
     * @param {Object|Object[]} images Image(s) to process. // TODO: support 2D RawImages
     * @param {Object} [options]
     * @param {boolean|null} [options.do_image_splitting=null] Override for the config's `do_image_splitting`.
     * @param {boolean} [options.return_row_col_info=false] Whether to include per-image patch `rows`/`cols` counts.
     * @returns {Promise<Object>} `pixel_values`, `pixel_attention_mask`, size info, and optionally `rows`/`cols`.
     */
    // /** @param {RawImage|RawImage[]|RawImage[][]} images */
    async _call(images, {
        do_image_splitting = null,
        return_row_col_info = false,
    } = {}) {
        // TODO: support 2D RawImages
        if (!Array.isArray(images)) {
            images = [images];
        }

        let images_list = await Promise.all(images.map(x => this.preprocess(x)));

        // Original sizes of images
        const original_sizes = images_list.map(x => x.original_size);

        // Reshaped sizes of images, before padding or cropping
        const reshaped_input_sizes = images_list.map(x => x.reshaped_input_size);

        // Convert images to 4D tensors for easier processing
        images_list.forEach(x => x.pixel_values.unsqueeze_(0));

        // One entry per (outer) batch; only pushed to, never reassigned.
        const images_list_rows = [];
        const images_list_cols = [];

        const { longest_edge } = this.max_image_size;

        if (do_image_splitting ?? this.do_image_splitting) {
            const image_rows = new Array(images_list.length);
            const image_cols = new Array(images_list.length);

            // We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
            images_list = await Promise.all(images_list.map(async (x, i) => {
                const new_size = this.get_resize_for_vision_encoder(x.pixel_values, longest_edge);

                const resized = await interpolate_4d(x.pixel_values, {
                    size: [new_size.height, new_size.width],
                });

                const { frames, num_splits_h, num_splits_w } = await this.split_image(resized, this.max_image_size);
                image_rows[i] = num_splits_h;
                image_cols[i] = num_splits_w;
                return cat(frames, 0);
            }));

            images_list_rows.push(image_rows);
            images_list_cols.push(image_cols);
        } else {
            // No splitting: resize every image to a square of the maximum size.
            /** @type {[number, number]} */
            const size = [longest_edge, longest_edge];
            images_list = await Promise.all(
                images_list.map(x => interpolate_4d(x.pixel_values, { size }))
            );

            images_list_rows.push(new Array(images_list.length).fill(0));
            images_list_cols.push(new Array(images_list.length).fill(0));
        }

        // Stack pixel values
        // TODO: support 2D images inputs
        const pixel_values = cat(images_list, 0);
        pixel_values.unsqueeze_(0);

        // TODO: Improve pixel_attention_mask (currently all-true, i.e. no padding is masked out)
        const [b, n, c, h, w] = pixel_values.dims;
        const pixel_attention_mask = full([b, n, h, w], true);

        return {
            pixel_values,
            pixel_attention_mask,

            original_sizes,
            reshaped_input_sizes,
            ...(
                return_row_col_info
                    ? { rows: images_list_rows, cols: images_list_cols }
                    : {}
            ),
        }
    }

    /**
     * Split an image into patches of at most `longest_edge` per side, then
     * append a copy of the full image resized to `longest_edge` square.
     * @param {import('../../utils/tensor.js').Tensor} pixel_values 4D image tensor.
     * @param {Object} options
     * @param {number} options.longest_edge Maximum height/width of each patch.
     * @returns {Promise<{frames: Object[], num_splits_h: number, num_splits_w: number}>}
     * The patch tensors (patches first, global image last; just the image when no split
     * occurred) and the number of splits per axis (0 when no split occurred).
     */
    async split_image(pixel_values, { longest_edge }) {
        const max_height = longest_edge;
        const max_width = longest_edge;

        const frames = [];

        const [height, width] = pixel_values.dims.slice(-2);

        let num_splits_h = 0, num_splits_w = 0;

        if (height > max_height || width > max_width) {
            // Calculate the number of splits
            num_splits_h = Math.ceil(height / max_height);
            num_splits_w = Math.ceil(width / max_width);

            // Calculate the optimal width and height for the sub-images
            const optimal_height = Math.ceil(height / num_splits_h);
            const optimal_width = Math.ceil(width / num_splits_w);

            // Iterate through each row and column
            for (let r = 0; r < num_splits_h; r++) {
                for (let c = 0; c < num_splits_w; c++) {
                    // Calculate the starting point of the crop
                    const start_x = c * optimal_width;
                    const start_y = r * optimal_height;

                    // Calculate the ending point of the crop
                    const end_x = Math.min(start_x + optimal_width, width);
                    const end_y = Math.min(start_y + optimal_height, height);

                    // Crop the image
                    frames.push(pixel_values.slice(null, null, [start_y, end_y], [start_x, end_x]));
                }
            }

            // Resize the global image to match max dimensions for memory efficiency
            const global_image_height = max_height;
            const global_image_width = max_width;

            if (height !== global_image_height || width !== global_image_width) {
                pixel_values = await interpolate_4d(pixel_values, {
                    size: [global_image_height, global_image_width],
                })
            }
        }

        frames.push(pixel_values);

        return { frames, num_splits_h, num_splits_w };
    }
}

0 commit comments

Comments
 (0)