# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of building an image encoder of Qwen 2.5 VL model."""

import dataclasses
from typing import Optional

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import attention_utils
from ai_edge_torch.generative.layers import builder
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
from torch import nn
import torch.nn.functional as F

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="visual.blocks.{}.mlp.up_proj",
    ff_down_proj="visual.blocks.{}.mlp.down_proj",
    ff_gate_proj="visual.blocks.{}.mlp.gate_proj",
    attn_fused_qkv_proj="visual.blocks.{}.attn.qkv",
    attn_output_proj="visual.blocks.{}.attn.proj",
    pre_attn_norm="visual.blocks.{}.norm1",
    post_attn_norm="visual.blocks.{}.norm2",
    embedding="visual.patch_embed.proj",
    final_norm="visual.merger.ln_q",
)

MERGER_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="visual.merger.mlp.0",
    ff_down_proj="visual.merger.mlp.2",
)
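# These merger tensor names are used by build_image_encoder() below to load
# the merger's w1/w2 weights directly from the checkpoint state, separately
# from the per-block TENSOR_NAMES loader.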


@dataclasses.dataclass
class QwenVLMergerConfig:
  """Merger parameters."""

  activation: cfg.ActivationConfig
  intermediate_size: int
  out_embedding_dim: int
  use_bias: bool = False


@dataclasses.dataclass
class QwenVLImageConfig(cfg.ModelConfig):
  """Model config for the image encoder of the Qwen 2.5 VL model."""

  merger_config: Optional[QwenVLMergerConfig] = None
  window_size: Optional[int] = None
  spatial_merge_size: Optional[int] = None
  full_atten_block_indexes: Optional[list[int]] = None


class QwenVLMerger(nn.Module):
  """Merger of Qwen 2.5 VL models from the Edge Generative API.

  It's based on Qwen2_5_VLPatchMerger.
  """

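  # The merger folds each group of spatial_merge_size**2 neighboring patch
  # embeddings into a single token: with embedding_dim=1280 and
  # spatial_merge_size=2, each input row has 1280 * 4 = 5120 features
  # (the merger intermediate_size), which are projected down to the decoder's
  # embedding dimension (out_embedding_dim).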
  def __init__(self, config: QwenVLImageConfig):
    super().__init__()
    self.intermediate_size = config.merger_config.intermediate_size
    self.w1 = nn.Linear(
        self.intermediate_size,
        self.intermediate_size,
        bias=config.merger_config.use_bias,
    )
    self.act = builder.get_activation(config.merger_config.activation)
    self.w2 = nn.Linear(
        self.intermediate_size,
        config.merger_config.out_embedding_dim,
        bias=config.merger_config.use_bias,
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x_reshaped = x.view(-1, self.intermediate_size)
    return self.w2(self.act(self.w1(x_reshaped)))


class QwenVLImageEncoder(nn.Module):
  """Image encoder of Qwen 2.5 VL models from the Edge Generative API."""

  def __init__(self, config: QwenVLImageConfig):
    super().__init__()

    # Tensor shape used to reshape pixel_values in forward() and various places.
    self.kernel_size = (
        -1,  # batch size
        config.image_embedding.channels,
        config.image_embedding.temporal_patch_size,
        config.image_embedding.patch_size,
        config.image_embedding.patch_size,
    )
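    # The patch embedding is a Conv3d whose kernel and stride both equal the
    # (temporal_patch_size, patch_size, patch_size) patch shape, so each
    # non-overlapping patch is mapped to one embedding_dim-sized vector.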
    self.tok_embedding = nn.Conv3d(
        in_channels=self.kernel_size[1],
        out_channels=config.embedding_dim,
        kernel_size=self.kernel_size[2:],
        stride=self.kernel_size[2:],
        padding=0,
        bias=config.embedding_use_bias,
    )

    self.transformer_blocks = nn.ModuleList(
        attention.TransformerBlock(config.block_config(idx), config)
        for idx in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.merger = QwenVLMerger(config)
    self.config = config

  @torch.inference_mode
  def forward(
      self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
  ) -> torch.Tensor:
    # Get window index and sequence lengths to rearrange the input tensor.
    window_index, cu_seqlens = self._get_window_index(grid_thw)

    # Embed the image and rearrange the embedding tensor.
    pixel_reshaped = pixel_values.view(self.kernel_size)
    x = self.tok_embedding(pixel_reshaped)
    x = x.view(-1, self.config.embedding_dim)
    x = self._rearrange(x, window_index).unsqueeze(0)

    # Get RoPE and attention mask arranged according to the window index.
    cos, sin = self._get_rope(grid_thw)
    rope = (
        self._rearrange(cos, window_index),
        self._rearrange(sin, window_index),
    )

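    # Most blocks use the block-diagonal window mask so tokens attend only
    # within their own window; the blocks listed in full_atten_block_indexes
    # instead get an all-zero mask, i.e. full attention over the sequence.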
    mask = self._get_mask(x.shape[1], cu_seqlens)
    full_mask = torch.zeros(x.shape[:2])
    for i, block in enumerate(self.transformer_blocks):
      x = block(
          x,
          rope=rope,
          mask=full_mask if i in self.config.full_atten_block_indexes else mask,
      )

    y = self.merger(self.final_norm(x))
    # Arrange the output back to the original order.
    reverse_index = torch.argsort(window_index)
    return y[reverse_index, ...]

  def _get_rope(
      self, grid_thw: torch.Tensor
  ) -> tuple[torch.Tensor, torch.Tensor]:
    """Get RoPE for Qwen VL model based on image grid information.

    It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
    modified accordingly.
    """
    pos_ids = []
    for t, h, w in grid_thw:
      hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
      hpos_ids = hpos_ids.reshape(
          h // self.config.spatial_merge_size,
          self.config.spatial_merge_size,
          w // self.config.spatial_merge_size,
          self.config.spatial_merge_size,
      )
      hpos_ids = hpos_ids.permute(0, 2, 1, 3)
      hpos_ids = hpos_ids.flatten()

      wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
      wpos_ids = wpos_ids.reshape(
          h // self.config.spatial_merge_size,
          self.config.spatial_merge_size,
          w // self.config.spatial_merge_size,
          self.config.spatial_merge_size,
      )
      wpos_ids = wpos_ids.permute(0, 2, 1, 3)
      wpos_ids = wpos_ids.flatten()
      pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    pos_ids = torch.cat(pos_ids, dim=0)
    max_grid_size = grid_thw[:, 1:].max()

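    # The cache covers the largest grid dimension with head_dim // 2 entries
    # per position; indexing it with the stacked (h, w) ids and flattening
    # yields a full head_dim-sized rotary embedding per token.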
    cos, sin = attention_utils.build_rope_cache(
        max_grid_size,
        # ROPE parameters for all attn_configs are the same. Take the first one.
        self.config.block_config(0).attn_config.head_dim // 2,
    )
    return cos[pos_ids].flatten(1), sin[pos_ids].flatten(1)

  def _get_window_index(self, grid_thw: torch.Tensor):
    """Get window index for Qwen VL model to rearrange the input tensor.

    It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
    and modified accordingly.
    """
    window_index: list = []
    cu_window_seqlens: list = [0]
    window_index_id = 0
    vit_merger_window_size = (
        self.config.window_size
        // self.config.spatial_merge_size
        // self.config.image_embedding.patch_size
    )

    for grid_t, grid_h, grid_w in grid_thw:
      llm_grid_h, llm_grid_w = (
          grid_h // self.config.spatial_merge_size,
          grid_w // self.config.spatial_merge_size,
      )
      index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
          grid_t, llm_grid_h, llm_grid_w
      )
      pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
      pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
      num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
      num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
      index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
      index_padded = index_padded.reshape(
          grid_t,
          num_windows_h,
          vit_merger_window_size,
          num_windows_w,
          vit_merger_window_size,
      )
      index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
          grid_t,
          num_windows_h * num_windows_w,
          vit_merger_window_size,
          vit_merger_window_size,
      )
      seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
      index_padded = index_padded.reshape(-1)
      index_new = index_padded[index_padded != -100]
      window_index.append(index_new + window_index_id)
      spatial_merge_unit = (
          self.config.spatial_merge_size * self.config.spatial_merge_size
      )
      cu_seqlens_tmp = (
          seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
      )
      cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
      window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()

    window_index = torch.cat(window_index, dim=0)
    cu_window_seqlens = torch.tensor(cu_window_seqlens)
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
    return window_index, cu_window_seqlens

  def _rearrange(
      self, x: torch.Tensor, window_index: torch.Tensor
  ) -> torch.Tensor:
    """Rearrange the tensor according to window_index.

    It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
    modified accordingly.
    """
    size = x.shape[0]
    spatial_merge_unit = (
        self.config.spatial_merge_size * self.config.spatial_merge_size
    )
    x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
    x_rearranged = x_reshaped[window_index, ...]
    return x_rearranged.view(size, -1)

  def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Get attention mask for Qwen VL model.

    It's copied from Qwen2_5_VLVisionAttention.forward() and modified
    accordingly.
    """
    mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
    for i in range(1, len(cu_seqlens)):
      mask[
          ...,
          cu_seqlens[i - 1] : cu_seqlens[i],
          cu_seqlens[i - 1] : cu_seqlens[i],
      ] = 0
    return mask


def get_image_encoder_config() -> QwenVLImageConfig:
  """Returns the model config for the image encoder of a Qwen 2.5 VL model.

  Returns:
    The model config for the image encoder of a Qwen 2.5 VL model.
  """
  image_embedding_config = cfg.ImageEmbeddingConfig(
      channels=3,
      image_size=0,  # Not used in image encoder.
      patch_size=14,
      temporal_patch_size=2,
  )
  attn_config = cfg.AttentionConfig(
      num_heads=16,
      head_dim=80,
      num_query_groups=16,
      qkv_transpose_before_split=True,
      qkv_use_bias=True,
      output_proj_use_bias=True,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
      intermediate_size=3420,
      use_bias=True,
  )
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-6,
  )
  block_config = cfg.TransformerBlockConfig(
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
  )
  merger_config = QwenVLMergerConfig(
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU),
      intermediate_size=5120,  # embedding_dim(1280) * spatial_merge_size(2)^2
      out_embedding_dim=2048,  # embedding_dim of decoder config.
      use_bias=True,
  )
  config = QwenVLImageConfig(
      vocab_size=0,  # Not used in image encoder.
      num_layers=32,
      max_seq_len=0,  # Not used in image encoder.
      embedding_dim=1280,
      image_embedding=image_embedding_config,
      block_configs=block_config,
      final_norm_config=norm_config,
      merger_config=merger_config,
      window_size=112,
      spatial_merge_size=2,
      full_atten_block_indexes=[7, 15, 23, 31],
      # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
      # enable_hlfb=True,
  )
  return config


def get_fake_image_encoder_config() -> QwenVLImageConfig:
  """Returns a scaled-down model config used for testing."""
  config = get_image_encoder_config()
  # Qwen VL image encoder has only one block config.
  config.block_config(0).ff_config.intermediate_size = 128
  config.image_embedding.patch_size = 2
  config.num_layers = 2
  config.merger_config.intermediate_size = 128
  return config


def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
  config = get_image_encoder_config()
  encoder = QwenVLImageEncoder(config)
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
  # Loosen the strictness because only the image encoder is being loaded.
  loader.load(encoder, strict=False)

  # Load merger weights.
  merger_loader = loading_utils.ModelLoader(checkpoint_path, None)
  state = merger_loader.get_state()
  w1_state = dict()
  w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
  if config.merger_config.use_bias:
    w1_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.bias")
  encoder.merger.w1.load_state_dict(w1_state)

  w2_state = dict()
  w2_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.weight")
  if config.merger_config.use_bias:
    w2_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.bias")
  encoder.merger.w2.load_state_dict(w2_state)

  encoder.eval()
  return encoder
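

# A minimal usage sketch. The checkpoint path is hypothetical, and the
# pixel_values/grid_thw tensors are assumed to come from the Qwen 2.5 VL
# image preprocessor:
#
#   encoder = build_image_encoder("/path/to/qwen2.5-vl-checkpoint")
#   image_embeddings = encoder(pixel_values, grid_thw)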