Skip to content

Commit f9c59f4

Browse files
committed
Add support for batched 2d images in idefics3 processor
1 parent 832a7ce commit f9c59f4

File tree

2 files changed

+98
-49
lines changed

2 files changed

+98
-49
lines changed

src/models/idefics3/image_processing_idefics3.js

Lines changed: 97 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import {
44
ImageProcessor,
55
} from "../../base/image_processors_utils.js";
6-
import { cat, full, interpolate_4d } from "../../utils/tensor.js";
6+
import { cat, full, interpolate_4d, stack } from "../../utils/tensor.js";
77

88
export class Idefics3ImageProcessor extends ImageProcessor {
99
constructor(config) {
@@ -13,9 +13,14 @@ export class Idefics3ImageProcessor extends ImageProcessor {
1313
this.max_image_size = config.max_image_size;
1414
}
1515

16+
/**
17+
* @typedef {import('../../utils/image.js').RawImage} RawImage
18+
* @typedef {import('../../utils/tensor.js').Tensor} Tensor
19+
*/
20+
1621
/**
1722
* Calculate the size to resize an image to, so that its dimensions are multiples of `vision_encoder_max_size` while preserving the aspect ratio.
18-
* @param {import('../../utils/tensor.js').Tensor} pixel_values Tensor of the image to resize.
23+
* @param {Tensor} pixel_values Tensor of the image to resize.
1924
* @param {number} vision_encoder_max_size Maximum size of the output image. If the image is larger than this size,
2025
* it will be split into patches of this size, and the original image will be concatenated with the patches, resized to max_size.
2126
*/
@@ -35,72 +40,116 @@ export class Idefics3ImageProcessor extends ImageProcessor {
3540
return { height, width };
3641
}
3742

38-
// /** @param {RawImage|RawImage[]|RawImage[][]} images */
43+
/** @param {RawImage|RawImage[]|RawImage[][]} images */
3944
async _call(images, {
4045
do_image_splitting = null,
4146
return_row_col_info = false,
4247
} = {}) {
43-
// TODO: support 2D RawImages
48+
49+
/** @type {RawImage[][]} */
50+
let batched_2d_images;
4451
if (!Array.isArray(images)) {
45-
images = [images];
52+
batched_2d_images = [[images]];
53+
} else {
54+
if (images.length === 0 || !images[0]) {
55+
throw new Error("No images provided.");
56+
}
57+
if (!Array.isArray(images[0])) {
58+
batched_2d_images = [/** @type {RawImage[]} */(images)];
59+
} else {
60+
batched_2d_images = /** @type {RawImage[][]} */(images);
61+
}
4662
}
4763

48-
let images_list = await Promise.all(images.map(x => this.preprocess(x)));
64+
// List of tensors, each with shape [patches, channels, height, width]
65+
let all_pixel_values = [];
66+
let images_list_rows = [];
67+
let images_list_cols = [];
4968

50-
// Original sizes of images
51-
const original_sizes = images_list.map(x => x.original_size);
69+
const original_sizes = [];
70+
const reshaped_input_sizes = [];
71+
for (const image_batch of batched_2d_images) {
5272

53-
// Reshaped sizes of images, before padding or cropping
54-
const reshaped_input_sizes = images_list.map(x => x.reshaped_input_size);
73+
let images_list = await Promise.all(image_batch.map(x => this.preprocess(x)));
5574

56-
// Convert images to 4D tensors for easier processing
57-
images_list.forEach(x => x.pixel_values.unsqueeze_(0));
75+
// Original sizes of images
76+
original_sizes.push(...images_list.map(x => x.original_size));
5877

59-
let pixel_values;
60-
let images_list_rows = [];
61-
let images_list_cols = [];
78+
// Reshaped sizes of images, before padding or cropping
79+
reshaped_input_sizes.push(...images_list.map(x => x.reshaped_input_size));
6280

63-
const { longest_edge } = this.max_image_size;
81+
// Convert images to 4D tensors for easier processing
82+
images_list.forEach(x => x.pixel_values.unsqueeze_(0));
6483

65-
if (do_image_splitting ?? this.do_image_splitting) {
66-
let image_rows = new Array(images_list.length);
67-
let image_cols = new Array(images_list.length);
84+
const { longest_edge } = this.max_image_size;
6885

69-
// We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
70-
images_list = await Promise.all(images_list.map(async (x, i) => {
71-
const new_size = this.get_resize_for_vision_encoder(x.pixel_values, longest_edge);
86+
/** @type {Tensor[]} */
87+
let images_tensor;
88+
if (do_image_splitting ?? this.do_image_splitting) {
89+
let image_rows = new Array(images_list.length);
90+
let image_cols = new Array(images_list.length);
7291

73-
const resized = await interpolate_4d(x.pixel_values, {
74-
size: [new_size.height, new_size.width],
75-
});
92+
// We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
93+
images_tensor = await Promise.all(images_list.map(async (x, i) => {
94+
const new_size = this.get_resize_for_vision_encoder(x.pixel_values, longest_edge);
7695

77-
const { frames, num_splits_h, num_splits_w } = await this.split_image(resized, this.max_image_size);
78-
image_rows[i] = num_splits_h;
79-
image_cols[i] = num_splits_w;
80-
return cat(frames, 0);
81-
}));
96+
const resized = await interpolate_4d(x.pixel_values, {
97+
size: [new_size.height, new_size.width],
98+
});
8299

83-
images_list_rows.push(image_rows);
84-
images_list_cols.push(image_cols);
85-
} else {
86-
/** @type {[number, number]} */
87-
const size = [longest_edge, longest_edge];
88-
images_list = await Promise.all(
89-
images_list.map(x => interpolate_4d(x.pixel_values, { size }))
90-
);
91-
92-
images_list_rows.push(new Array(images_list.length).fill(0));
93-
images_list_cols.push(new Array(images_list.length).fill(0));
100+
const { frames, num_splits_h, num_splits_w } = await this.split_image(resized, this.max_image_size);
101+
image_rows[i] = num_splits_h;
102+
image_cols[i] = num_splits_w;
103+
return cat(frames, 0);
104+
}));
105+
106+
images_list_rows.push(image_rows);
107+
images_list_cols.push(image_cols);
108+
109+
} else {
110+
/** @type {[number, number]} */
111+
const size = [longest_edge, longest_edge];
112+
images_tensor = await Promise.all(
113+
images_list.map(x => interpolate_4d(x.pixel_values, { size }))
114+
);
115+
116+
images_list_rows.push(new Array(images_list.length).fill(0));
117+
images_list_cols.push(new Array(images_list.length).fill(0));
118+
}
119+
120+
all_pixel_values.push(cat(images_tensor, 0));
94121
}
95122

96123
// Stack pixel values
97-
// TODO: support 2D images inputs
98-
pixel_values = cat(images_list, 0);
99-
pixel_values.unsqueeze_(0);
100-
101-
// TODO: Improve pixel_attention_mask
102-
const [b, n, c, h, w] = pixel_values.dims;
103-
const pixel_attention_mask = full([b, n, h, w], true);
124+
let pixel_values;
125+
let pixel_attention_mask;
126+
if (all_pixel_values.length === 1) {
127+
pixel_values = all_pixel_values[0];
128+
pixel_values.unsqueeze_(0);
129+
} else {
130+
// Add padding (if necessary) to images with fewer patches than the maximum number of patches
131+
const max_num_patches = Math.max(...all_pixel_values.map(x => x.dims.at(0)));
132+
133+
const [c, h, w] = all_pixel_values[0].dims.slice(1);
134+
135+
pixel_attention_mask = full([all_pixel_values.length, max_num_patches, h, w], 1);
136+
const pixel_attention_mask_data = pixel_attention_mask.data;
137+
const pixel_attention_mask_stride = max_num_patches * h * w;
138+
for (let i = 0; i < all_pixel_values.length; ++i) {
139+
const num_patches = all_pixel_values[i].dims[0];
140+
if (num_patches < max_num_patches) {
141+
all_pixel_values[i] = cat([
142+
all_pixel_values[i],
143+
full([max_num_patches - num_patches, c, h, w], 0),
144+
], 0);
145+
146+
const start_offset = i * pixel_attention_mask_stride + num_patches * h * w;
147+
const end_offset = (i + 1) * pixel_attention_mask_stride;
148+
pixel_attention_mask_data.fill(0, start_offset, end_offset);
149+
}
150+
}
151+
pixel_values = stack(all_pixel_values, 0);
152+
}
104153

105154
return {
106155
pixel_values,

src/models/idefics3/processing_idefics3.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ export class Idefics3Processor extends Processor {
6767
/**
6868
*
6969
* @param {string|string[]} text
70-
* @param {RawImage|RawImage[]} images
70+
* @param {RawImage|RawImage[]|RawImage[][]} images
7171
* @returns {Promise<any>}
7272
*/
7373
async _call(text, images = null, options = {}) {

0 commit comments

Comments
 (0)