1616"""Example of building an image encoder of Qwen 2.5 VL model."""
1717
1818import dataclasses
19- from typing import Optional
19+ from typing import List , Optional , Tuple
2020
2121from ai_edge_torch .generative .layers import attention
2222from ai_edge_torch .generative .layers import attention_utils
@@ -93,7 +93,7 @@ def __init__(self, config: QwenVLImageConfig):
 
     # Tensor shape used to reshape pixel_values in forward() and various places.
     self.kernel_size = (
-        -1,  # batch size
+        -1,  # pixel_values.size(0)
         config.image_embedding.channels,
        config.image_embedding.temporal_patch_size,
        config.image_embedding.patch_size,
@@ -118,28 +118,22 @@ def __init__(self, config: QwenVLImageConfig):
     )
     self.merger = QwenVLMerger(config)
     self.config = config
+    self.set_image_size(config.image_embedding.image_size)
 
   @torch.inference_mode
-  def forward(
-      self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
-  ) -> torch.Tensor:
-    # Get window index and sequence lengths to rearrange the input tensor.
-    window_index, cu_seqlens = self._get_window_index(grid_thw)
+  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    # Check that the pixel values size matches the grid size and image config.
+    assert pixel_values.size() == self.get_pixel_values_size(self.grid_thw)
 
     # Embed the image and rearrange the embedding tensor.
-    pixel_reshaped = pixel_values.view(self.kernel_size)
+    pixel_reshaped = pixel_values.reshape(self.kernel_size)
     x = self.tok_embedding(pixel_reshaped)
     x = x.view(-1, self.config.embedding_dim)
-    x = self._rearrange(x, window_index).unsqueeze(0)
+    x = self._rearrange(x, self.window_index).unsqueeze(0)
 
-    # Get RoPE and attention mask arranged according to the window index.
-    cos, sin = self._get_rope(grid_thw)
-    rope = (
-        self._rearrange(cos, window_index),
-        self._rearrange(sin, window_index),
-    )
+    rope = self._get_rope(self.grid_thw, self.window_index)
 
-    mask = self._get_mask(x.shape[1], cu_seqlens)
+    mask = self._get_mask(self.grid_thw, self.cu_seqlens)
     full_mask = torch.zeros(x.shape[:2])
     for i, block in enumerate(self.transformer_blocks):
       x = block(
@@ -150,10 +144,42 @@ def forward(
 
     y = self.merger.forward(self.final_norm(x))
     # Arrange the output back to the original order.
-    reverse_index = torch.argsort(window_index)
-    return y[reverse_index, ...]
-
-  def _get_rope(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    return y[self.reverse_index, ...]
+
+  def set_image_size(self, image_size: Tuple[int, int]):
+    """Set the image size and pre-calculate some values, including the mask."""
+    self.config.image_embedding.image_size = image_size
+    self.grid_thw = self.get_grid_thw()
+
+    # Precalculate the window index, which can't be lowered to HLO because of
+    # the non-concrete index in:
+    #   index_new = index_padded[index_padded != -100]
+    self.window_index, self.cu_seqlens = self._get_window_index(self.grid_thw)
+
+    # Precalculate the reverse index of window_index until the "vhlo.sort_v1"
+    # op is supported.
+    self.reverse_index = torch.argsort(self.window_index)
+
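Note: the non-concrete index called out in the comment above is easy to reproduce in isolation. A minimal sketch in plain PyTorch (not part of this change) of why index_padded[index_padded != -100] has a data-dependent shape and therefore must be precomputed before export:

import torch

index_padded = torch.tensor([0, 1, -100, 2, -100, 3])
# Boolean-mask indexing: the output length depends on the tensor's values,
# not its shape, so the result shape is unknowable at trace/export time.
index_new = index_padded[index_padded != -100]
assert index_new.shape == (4,)  # discoverable only by running the code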
+  def get_grid_thw(self, num_images: int = 1) -> List[Tuple[int, int, int]]:
+    """Calculate the grid size of the input images based on the image config."""
+    height, width = self.config.image_embedding.image_size
+    patch_height = height // self.config.image_embedding.patch_size
+    patch_width = width // self.config.image_embedding.patch_size
+    # Only images are supported, i.e. the temporal step size is always 1.
+    return [(1, patch_height, patch_width)] * num_images
+
+  def get_pixel_values_size(
+      self, grid_thw: List[Tuple[int, int, int]]
+  ) -> torch.Size:
+    """Calculate the size of the pixel values tensor."""
+    dim_0 = sum(t * h * w for t, h, w in grid_thw)
+    config = self.config.image_embedding
+    dim_1 = config.channels * config.temporal_patch_size * config.patch_size**2
+    return torch.Size((dim_0, dim_1))
+
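As a worked example for the two helpers above, assuming the default 476x644 image used by build_image_encoder() below (channels=3, patch_size=14, temporal_patch_size=2):

# get_grid_thw(): 476 // 14 = 34 patch rows, 644 // 14 = 46 patch columns,
# and the temporal dimension is always 1 for still images.
grid_thw = [(1, 34, 46)]

# get_pixel_values_size(): 1 * 34 * 46 = 1564 patches, each flattened to
# 3 * 2 * 14**2 = 1176 values, i.e. torch.Size([1564, 1176]).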
+  def _get_rope(
+      self, grid_thw: List[Tuple[int, int, int]], window_index: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Get RoPE for Qwen VL model based on image grid information.
 
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
@@ -182,16 +208,20 @@ def _get_rope(self, grid_thw: torch.Tensor) -> torch.Tensor:
       wpos_ids = wpos_ids.flatten()
       pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
     pos_ids = torch.cat(pos_ids, dim=0)
-    max_grid_size = grid_thw[:, 1:].max()
+    # Assume all the heights and widths are the same for all images.
+    max_grid_size = max(grid_thw[0][1], grid_thw[0][2])
 
     cos, sin = attention_utils.build_rope_cache(
         max_grid_size,
         # ROPE parameters for all attn_configs are the same. Take the first one.
         self.config.block_config(0).attn_config.head_dim // 2,
     )
-    return cos[pos_ids].flatten(1), sin[pos_ids].flatten(1)
+    return (
+        self._rearrange(cos[pos_ids].flatten(1), window_index),
+        self._rearrange(sin[pos_ids].flatten(1), window_index),
+    )
 
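For intuition, a simplified sketch (hypothetical toy sizes, plain PyTorch, ignoring the spatial-merge reordering the real code performs) of the position-id lookup in _get_rope(): each patch gets a (row, column) pair, and the pairs index a table of size max_grid_size so the two halves of each head dimension encode height and width:

import torch

grid_h, grid_w, half_dim = 4, 6, 8  # toy sizes, far smaller than the real ones
hpos = torch.arange(grid_h).unsqueeze(1).expand(-1, grid_w).flatten()
wpos = torch.arange(grid_w).unsqueeze(0).expand(grid_h, -1).flatten()
pos_ids = torch.stack([hpos, wpos], dim=-1)  # (grid_h * grid_w, 2)

cos_table = torch.randn(max(grid_h, grid_w), half_dim)  # stand-in RoPE cache
cos = cos_table[pos_ids].flatten(1)  # (grid_h * grid_w, 2 * half_dim)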
-  def _get_window_index(self, grid_thw: torch.Tensor):
+  def _get_window_index(self, grid_thw: List[Tuple[int, int, int]]):
     """Get window index for Qwen VL model to rearrange the input tensor.
 
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
@@ -207,13 +237,10 @@ def _get_window_index(self, grid_thw: torch.Tensor):
     )
 
     for grid_t, grid_h, grid_w in grid_thw:
-      llm_grid_h, llm_grid_w = (
-          grid_h // self.config.spatial_merge_size,
-          grid_w // self.config.spatial_merge_size,
-      )
-      index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
-          grid_t, llm_grid_h, llm_grid_w
-      )
+      llm_grid_h = grid_h // self.config.spatial_merge_size
+      llm_grid_w = grid_w // self.config.spatial_merge_size
+      index = torch.arange(grid_t * llm_grid_h * llm_grid_w)
+      index = index.reshape((grid_t, llm_grid_h, llm_grid_w))
       pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
       pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
       num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
@@ -236,18 +263,14 @@ def _get_window_index(self, grid_thw: torch.Tensor):
       index_padded = index_padded.reshape(-1)
       index_new = index_padded[index_padded != -100]
       window_index.append(index_new + window_index_id)
-      spatial_merge_unit = (
-          self.config.spatial_merge_size * self.config.spatial_merge_size
-      )
+      spatial_merge_unit = self.config.spatial_merge_size**2
       cu_seqlens_tmp = (
           seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
       )
       cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
-      window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+      window_index_id += grid_t * llm_grid_h * llm_grid_w
 
     window_index = torch.cat(window_index, dim=0)
-    cu_window_seqlens = torch.tensor(cu_window_seqlens)
-    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
     return window_index, cu_window_seqlens
 
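In rough terms, _get_window_index() tiles the merged patch grid into vit_merger_window_size x vit_merger_window_size windows, pads ragged edges with -100, and flattens the grid window by window. A toy illustration (hypothetical 4x4 grid with 2x2 windows and no padding needed, not part of this change):

import torch

index = torch.arange(16).reshape(1, 2, 2, 2, 2)  # (t, wins_h, win_h, wins_w, win_w)
window_index = index.permute(0, 1, 3, 2, 4).reshape(-1)
# tensor([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15]):
# each group of four is one 2x2 window read out contiguously.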
   def _rearrange(
@@ -258,20 +281,20 @@ def _rearrange(
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
     modified accordingly.
     """
-    size = x.shape[0]
-    spatial_merge_unit = (
-        self.config.spatial_merge_size * self.config.spatial_merge_size
-    )
-    x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
+    spatial_merge_unit = self.config.spatial_merge_size**2
+    x_reshaped = x.view(x.size(0) // spatial_merge_unit, spatial_merge_unit, -1)
     x_rearranged = x_reshaped[window_index, ...]
-    return x_rearranged.view(size, -1)
+    return x_rearranged.view(x.shape)
 
-  def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
+  def _get_mask(
+      self, grid_thw: List[Tuple[int, int, int]], cu_seqlens: List[int]
+  ) -> torch.Tensor:
     """Get attention mask for Qwen VL model.
 
     It's copied from Qwen2_5_VLVisionAttention.forward() and modified
     accordingly.
     """
+    seqlen = self.get_pixel_values_size(grid_thw)[0]
     mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
     for i in range(1, len(cu_seqlens)):
       mask[
@@ -282,15 +305,15 @@ def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
     return mask
 
 
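The mask built above is block-diagonal: patches attend only within their own window. A compact standalone sketch of the same construction (hypothetical seqlen and cu_seqlens, plain PyTorch):

import torch

seqlen, cu_seqlens = 6, [0, 2, 6]  # two windows of 2 and 4 patches
mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
for i in range(1, len(cu_seqlens)):
  lo, hi = cu_seqlens[i - 1], cu_seqlens[i]
  mask[..., lo:hi, lo:hi] = 0  # patches within a window see each other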
-def get_image_encoder_config() -> QwenVLImageConfig:
+def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
   """Returns the model config for the image encoder of a Qwen 2.5 VL model.
 
+  Args:
+    image_size: The (height, width) of the input image in pixels.
+
   Returns:
     The model config for the image encoder of a Qwen 2.5 VL model.
   """
   image_embedding_config = cfg.ImageEmbeddingConfig(
       channels=3,
-      image_size=0,  # Not used in image encoder.
+      image_size=image_size,
       patch_size=14,
       temporal_patch_size=2,
   )
@@ -336,15 +359,13 @@ def get_image_encoder_config() -> QwenVLImageConfig:
       window_size=112,
       spatial_merge_size=2,
       full_atten_block_indexes=[7, 15, 23, 31],
-      # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
-      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
-      # enable_hlfb=True,
+      enable_hlfb=True,
   )
   return config
 
 
 def get_fake_image_encoder_config() -> QwenVLImageConfig:
-  config = get_image_encoder_config()
+  config = get_image_encoder_config((8, 12))
   # The Qwen VL image encoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.image_embedding.patch_size = 2
@@ -353,8 +374,11 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
   return config
 
 
-def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
-  config = get_image_encoder_config()
+def build_image_encoder(
+    checkpoint_path: str,
+    image_size: Tuple[int, int] = (34 * 14, 46 * 14),
+) -> QwenVLImageEncoder:
+  config = get_image_encoder_config(image_size)
   encoder = QwenVLImageEncoder(config)
   load_image_encoder(checkpoint_path, encoder)
   encoder.eval()
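A hypothetical end-to-end use of the new entry point; the checkpoint path is a placeholder, and the input shape comes from get_pixel_values_size() as asserted in forward():

import torch

encoder = build_image_encoder("/path/to/qwen_vl_checkpoint")  # placeholder path
pixel_values = torch.randn(encoder.get_pixel_values_size(encoder.grid_thw))
embeddings = encoder(pixel_values)  # default image_size (476, 644)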