33import {
44 ImageProcessor ,
55} from "../../base/image_processors_utils.js" ;
6- import { cat , full , interpolate_4d } from "../../utils/tensor.js" ;
6+ import { cat , full , interpolate_4d , stack } from "../../utils/tensor.js" ;
77
88export class Idefics3ImageProcessor extends ImageProcessor {
99 constructor ( config ) {
@@ -13,9 +13,14 @@ export class Idefics3ImageProcessor extends ImageProcessor {
1313 this . max_image_size = config . max_image_size ;
1414 }
1515
16+ /**
17+ * @typedef {import('../../utils/image.js').RawImage } RawImage
18+ * @typedef {import('../../utils/tensor.js').Tensor } Tensor
19+ */
20+
1621 /**
1722 * Calculate size to resize images to, to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
18- * @param {import('../../utils/tensor.js'). Tensor } pixel_values Tensor of the image to resize.
23+ * @param {Tensor } pixel_values Tensor of the image to resize.
1924 * @param {number } vision_encoder_max_size Maximum size of the output image. If the image is larger than this size,
2025 * it will be split into patches of this size, and the original image will be concatenated with the patches, resized to max_size.
2126 */
@@ -35,72 +40,116 @@ export class Idefics3ImageProcessor extends ImageProcessor {
3540 return { height, width } ;
3641 }
3742
38- // / ** @param {RawImage|RawImage[]|RawImage[][] } images */
43+ /** @param {RawImage|RawImage[]|RawImage[][] } images */
3944 async _call ( images , {
4045 do_image_splitting = null ,
4146 return_row_col_info = false ,
4247 } = { } ) {
43- // TODO: support 2D RawImages
48+
49+ /** @type {RawImage[][] } */
50+ let batched_2d_images ;
4451 if ( ! Array . isArray ( images ) ) {
45- images = [ images ] ;
52+ batched_2d_images = [ [ images ] ] ;
53+ } else {
54+ if ( images . length === 0 || ! images [ 0 ] ) {
55+ throw new Error ( "No images provided." ) ;
56+ }
57+ if ( ! Array . isArray ( images [ 0 ] ) ) {
58+ batched_2d_images = [ /** @type {RawImage[] } */ ( images ) ] ;
59+ } else {
60+ batched_2d_images = /** @type {RawImage[][] } */ ( images ) ;
61+ }
4662 }
4763
48- let images_list = await Promise . all ( images . map ( x => this . preprocess ( x ) ) ) ;
64+ // List of tensors, each with shape [patches, channels, height, width]
65+ let all_pixel_values = [ ] ;
66+ let images_list_rows = [ ] ;
67+ let images_list_cols = [ ] ;
4968
50- // Original sizes of images
51- const original_sizes = images_list . map ( x => x . original_size ) ;
69+ const original_sizes = [ ] ;
70+ const reshaped_input_sizes = [ ] ;
71+ for ( const image_batch of batched_2d_images ) {
5272
53- // Reshaped sizes of images, before padding or cropping
54- const reshaped_input_sizes = images_list . map ( x => x . reshaped_input_size ) ;
73+ let images_list = await Promise . all ( image_batch . map ( x => this . preprocess ( x ) ) ) ;
5574
56- // Convert images to 4D tensors for easier processing
57- images_list . forEach ( x => x . pixel_values . unsqueeze_ ( 0 ) ) ;
75+ // Original sizes of images
76+ original_sizes . push ( ... images_list . map ( x => x . original_size ) ) ;
5877
59- let pixel_values ;
60- let images_list_rows = [ ] ;
61- let images_list_cols = [ ] ;
78+ // Reshaped sizes of images, before padding or cropping
79+ reshaped_input_sizes . push ( ...images_list . map ( x => x . reshaped_input_size ) ) ;
6280
63- const { longest_edge } = this . max_image_size ;
81+ // Convert images to 4D tensors for easier processing
82+ images_list . forEach ( x => x . pixel_values . unsqueeze_ ( 0 ) ) ;
6483
65- if ( do_image_splitting ?? this . do_image_splitting ) {
66- let image_rows = new Array ( images_list . length ) ;
67- let image_cols = new Array ( images_list . length ) ;
84+ const { longest_edge } = this . max_image_size ;
6885
69- // We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
70- images_list = await Promise . all ( images_list . map ( async ( x , i ) => {
71- const new_size = this . get_resize_for_vision_encoder ( x . pixel_values , longest_edge ) ;
86+ /** @type {Tensor[] } */
87+ let images_tensor ;
88+ if ( do_image_splitting ?? this . do_image_splitting ) {
89+ let image_rows = new Array ( images_list . length ) ;
90+ let image_cols = new Array ( images_list . length ) ;
7291
73- const resized = await interpolate_4d ( x . pixel_values , {
74- size : [ new_size . height , new_size . width ] ,
75- } ) ;
92+ // We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
93+ images_tensor = await Promise . all ( images_list . map ( async ( x , i ) => {
94+ const new_size = this . get_resize_for_vision_encoder ( x . pixel_values , longest_edge ) ;
7695
77- const { frames, num_splits_h, num_splits_w } = await this . split_image ( resized , this . max_image_size ) ;
78- image_rows [ i ] = num_splits_h ;
79- image_cols [ i ] = num_splits_w ;
80- return cat ( frames , 0 ) ;
81- } ) ) ;
96+ const resized = await interpolate_4d ( x . pixel_values , {
97+ size : [ new_size . height , new_size . width ] ,
98+ } ) ;
8299
83- images_list_rows . push ( image_rows ) ;
84- images_list_cols . push ( image_cols ) ;
85- } else {
86- /** @type {[number, number] } */
87- const size = [ longest_edge , longest_edge ] ;
88- images_list = await Promise . all (
89- images_list . map ( x => interpolate_4d ( x . pixel_values , { size } ) )
90- ) ;
91-
92- images_list_rows . push ( new Array ( images_list . length ) . fill ( 0 ) ) ;
93- images_list_cols . push ( new Array ( images_list . length ) . fill ( 0 ) ) ;
100+ const { frames, num_splits_h, num_splits_w } = await this . split_image ( resized , this . max_image_size ) ;
101+ image_rows [ i ] = num_splits_h ;
102+ image_cols [ i ] = num_splits_w ;
103+ return cat ( frames , 0 ) ;
104+ } ) ) ;
105+
106+ images_list_rows . push ( image_rows ) ;
107+ images_list_cols . push ( image_cols ) ;
108+
109+ } else {
110+ /** @type {[number, number] } */
111+ const size = [ longest_edge , longest_edge ] ;
112+ images_tensor = await Promise . all (
113+ images_list . map ( x => interpolate_4d ( x . pixel_values , { size } ) )
114+ ) ;
115+
116+ images_list_rows . push ( new Array ( images_list . length ) . fill ( 0 ) ) ;
117+ images_list_cols . push ( new Array ( images_list . length ) . fill ( 0 ) ) ;
118+ }
119+
120+ all_pixel_values . push ( cat ( images_tensor , 0 ) ) ;
94121 }
95122
96123 // Stack pixel values
97- // TODO: support 2D images inputs
98- pixel_values = cat ( images_list , 0 ) ;
99- pixel_values . unsqueeze_ ( 0 ) ;
100-
101- // TODO: Improve pixel_attention_mask
102- const [ b , n , c , h , w ] = pixel_values . dims ;
103- const pixel_attention_mask = full ( [ b , n , h , w ] , true ) ;
124+ let pixel_values ;
125+ let pixel_attention_mask ;
126+ if ( all_pixel_values . length === 1 ) {
127+ pixel_values = all_pixel_values [ 0 ] ;
128+ pixel_values . unsqueeze_ ( 0 ) ;
129+ } else {
130+ // Add padding (if necessary) to images with fewer patches than the maximum number of patches
131+ const max_num_patches = Math . max ( ...all_pixel_values . map ( x => x . dims . at ( 0 ) ) ) ;
132+
133+ const [ c , h , w ] = all_pixel_values [ 0 ] . dims . slice ( 1 ) ;
134+
135+ pixel_attention_mask = full ( [ all_pixel_values . length , max_num_patches , h , w ] , 1 ) ;
136+ const pixel_attention_mask_data = pixel_attention_mask . data ;
137+ const pixel_attention_mask_stride = max_num_patches * h * w ;
138+ for ( let i = 0 ; i < all_pixel_values . length ; ++ i ) {
139+ const num_patches = all_pixel_values [ i ] . dims [ 0 ] ;
140+ if ( num_patches < max_num_patches ) {
141+ all_pixel_values [ i ] = cat ( [
142+ all_pixel_values [ i ] ,
143+ full ( [ max_num_patches - num_patches , c , h , w ] , 0 ) ,
144+ ] , 0 ) ;
145+
146+ const start_offset = i * pixel_attention_mask_stride + num_patches * h * w ;
147+ const end_offset = ( i + 1 ) * pixel_attention_mask_stride ;
148+ pixel_attention_mask_data . fill ( 0 , start_offset , end_offset ) ;
149+ }
150+ }
151+ pixel_values = stack ( all_pixel_values , 0 ) ;
152+ }
104153
105154 return {
106155 pixel_values,
0 commit comments