# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

|  | 16 | +"""Example of building a full-stack of Qwen 2.5 VL model.""" | 

import dataclasses
from typing import List, Optional, Tuple

from ai_edge_torch.generative.examples.qwen_vl import decoder
from ai_edge_torch.generative.examples.qwen_vl import image_encoder
import ai_edge_torch.generative.layers.kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
from torch import nn


@dataclasses.dataclass
class QwenVLConfig:
  """Qwen VL model configurations."""

  image_encoder_config: image_encoder.QwenVLImageConfig
  decoder_config: cfg.ModelConfig
  image_token_id: int
  mrope_section: List[int]


class QwenVL(nn.Module):
  """Qwen VL model from the Edge Generative API."""

  def __init__(self, config: QwenVLConfig):
    super().__init__()

    self.image_encoder = image_encoder.QwenVLImageEncoder(
        config.image_encoder_config
    )
    self.decoder = decoder.Decoder(config.decoder_config)
    # Adjustment added to input_pos so that RoPE is computed correctly in
    # forward() calls after an image has been handled.
    self.rope_pos_adjust = 0
    self.config = config

  @torch.inference_mode
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
      mask: Optional[torch.Tensor] = None,
      pixel_values: Optional[torch.Tensor] = None,
      grid_thw: Optional[torch.Tensor] = None,
      export_config: Optional[model_builder.ExportConfig] = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    if pixel_values is None:
      return self.decoder(
          tokens=tokens,
          input_pos=input_pos,
          kv_cache=kv_cache,
          mask=mask,
          rope=self._build_text_rope(input_pos),
          input_embeds=None,
          export_config=export_config,
      )

    input_embeds = self.decoder.tok_embedding(tokens)
    image_embeds = self.image_encoder(pixel_values, grid_thw).unsqueeze(0)

    # Merging image_embeds into the text embeddings, as done in
    # PaliGemmaForConditionalGeneration, could be written as:
    #
    #   image_mask = tokens == self.config.image_token_id
    #   image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
    #   input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
    #
    # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
    # Assume instead that the image is placed at the beginning of the input
    # sequence, wrapped with vision_start and vision_end tokens.
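    # Under that assumption the sequence layout is:
    #   [vision_start][image placeholder tokens ...][vision_end][text ...]
    # so the concatenation below simply replaces token-embedding positions
    # 1 .. image_embeds.shape[1] with the image embeddings.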
    input_embeds = torch.cat(
        (
            input_embeds[:, :1, :],
            image_embeds,
            input_embeds[:, image_embeds.shape[1] + 1 :, :],
        ),
        dim=1,
    )

    return self.decoder(
        tokens=None,
        input_pos=input_pos,
        kv_cache=kv_cache,
        mask=mask,
        input_embeds=input_embeds,
        rope=self._build_multimodal_rope(input_pos, grid_thw),
        export_config=export_config,
    )

  def _build_rope(
      self, rope_pos: torch.Tensor
  ) -> Tuple[torch.Tensor, torch.Tensor]:
    # RoPE parameters for all attn_configs are the same. Take the first one.
    attn_config = self.config.decoder_config.block_config(0).attn_config
    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
    return self.config.decoder_config.build_rope(
        rope_pos, n_elem, attn_config.rotary_base
    )

  def _build_text_rope(
      self, input_pos: torch.Tensor
  ) -> Tuple[torch.Tensor, torch.Tensor]:
    # Reset rope_pos_adjust to 0 when the input sequence starts from scratch,
    # i.e. when input_pos[0] == 0.
    if input_pos[0] == 0:
      self.rope_pos_adjust = 0
    return self._build_rope(input_pos + self.rope_pos_adjust)

  def _build_multimodal_rope(
      self, input_pos: torch.Tensor, grid_thw: torch.Tensor
  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Builds RoPE of multimodal input for the Qwen VL model.

    It's copied from Qwen2_5_VLForConditionalGeneration.get_rope_index() and
    simplified based on the assumption that an image is put at the beginning of
    the input sequence with vision start and vision end tokens.
    """
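    # Illustrative example (grid values are assumptions, not from a real
    # image): for a merged patch grid of 4 x 6, every image position gets
    # temporal index 1, height indices 1..4, and width indices 1..6, and the
    # vision_end token plus the following text start at max(4, 6) + 1 = 7 on
    # all three axes.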
    spatial_merge_size = self.config.image_encoder_config.spatial_merge_size
    height = grid_thw[0][1] // spatial_merge_size
    width = grid_thw[0][2] // spatial_merge_size
    image_pos_max = max(height, width)
    image_pos_count = height * width

    # The positions of the vision end token and the text tokens after the
    # image.
    text_pos_start = image_pos_max + 1
    text_pos_count = len(input_pos) - image_pos_count - 1
    text_pos = torch.arange(text_pos_start, text_pos_start + text_pos_count)
    # Set rope_pos_adjust since text_pos_start has changed.
    self.rope_pos_adjust = image_pos_max - image_pos_count
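    # For the assumed 4 x 6 grid above, the image occupies 24 placeholder
    # token slots but only advances the RoPE position by max(4, 6) = 6, so
    # rope_pos_adjust becomes 6 - 24 = -18 and later text-only forward()
    # calls shift input_pos back by 18 when building RoPE.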

    temporal_rope = self._build_image_text_rope(
        torch.ones(image_pos_count, dtype=torch.int), text_pos
    )
    height_rope = self._build_image_text_rope(
        torch.arange(1, height + 1).view(-1, 1).expand(-1, width).flatten(),
        text_pos,
    )
    width_rope = self._build_image_text_rope(
        torch.arange(1, width + 1).view(1, -1).expand(height, -1).flatten(),
        text_pos,
    )

    return (
        self._merge_ropes(temporal_rope[0], height_rope[0], width_rope[0]),
        self._merge_ropes(temporal_rope[1], height_rope[1], width_rope[1]),
    )

  def _build_image_text_rope(
      self, image_pos: torch.Tensor, text_pos: torch.Tensor
  ) -> Tuple[torch.Tensor, torch.Tensor]:
    return self._build_rope(
        torch.cat((torch.zeros(1, dtype=torch.int), image_pos, text_pos))
    )

  def _merge_ropes(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
    """Merges RoPE tensors based on apply_multimodal_rotary_pos_emb()."""
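    # With mrope_section = [16, 24, 24] (as set in get_model_config()), the
    # last dimension of the stacked RoPE tensors is split into chunks of
    # those sizes, and chunk i takes the (i % 3)-th of (temporal, height,
    # width), mirroring the HF Qwen2.5 VL implementation.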
    split = torch.stack([a, b, c]).split(self.config.mrope_section, dim=-1)
    return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)


def get_model_config(**kwargs) -> QwenVLConfig:
  """Returns the model config for a Qwen 2.5 VL model.

  Returns:
    The model config for a Qwen 2.5 VL model.
  """
  return QwenVLConfig(
      image_encoder_config=image_encoder.get_image_encoder_config(),
      decoder_config=decoder.get_decoder_config(**kwargs),
      image_token_id=151655,
      mrope_section=[16, 24, 24],
  )


def get_fake_model_config(**kwargs) -> QwenVLConfig:
  return QwenVLConfig(
      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
      decoder_config=decoder.get_fake_decoder_config(**kwargs),
      image_token_id=127,
      mrope_section=[16, 24, 24],
  )


def build_model(checkpoint_path: str, **kwargs) -> QwenVL:
  config = get_model_config(**kwargs)
  model = QwenVL(config)
  image_encoder.load_image_encoder(checkpoint_path, model.image_encoder)
  # Load the parameters of the decoder.
  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
  loader.load(model.decoder, strict=False)
  model.eval()
  return model
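
# Example usage (a minimal sketch; the checkpoint path and the preprocessed
# image inputs are assumptions, not defined in this module):
#
#   model = build_model("/path/to/qwen2.5-vl-checkpoint")
#   kv = kv_utils.KVCache.from_model_config(model.config.decoder_config)
#   output = model(
#       tokens=tokens,              # (1, seq_len) ids with image placeholders
#       input_pos=torch.arange(tokens.shape[1]),
#       kv_cache=kv,
#       pixel_values=pixel_values,  # preprocessed image patches
#       grid_thw=grid_thw,          # (1, 3) tensor of [t, h, w] patch counts
#   )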