# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of building the decoder of the PaliGemma2 3B model, which is Gemma2."""

from typing import Optional

from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch

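# The tensor names below map this re-authored decoder onto the original
# checkpoint; the "language_model.model..." prefixes follow the Hugging Face
# PaliGemma2 checkpoint layout. lm_head is left as None, presumably because
# Gemma-family models tie the output projection to the embedding table, so no
# separate lm_head tensor needs to be loaded.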
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
    embedding="language_model.model.embed_tokens",
    final_norm="language_model.model.norm",
    lm_head=None,
)


class Decoder2(gemma2.Gemma2):
  """The decoder of the PaliGemma2 3B model, which is Gemma2.

  Besides a tensor of text token IDs, forward() can also take a tensor of
  input embeddings, which may contain text embeddings, image embeddings, or
  both.
  """

  @torch.inference_mode
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
      input_embeds: Optional[torch.Tensor] = None,
      export_config: Optional[model_builder.ExportConfig] = None,
      called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    if input_embeds is None:
      return super().forward(tokens, input_pos, kv_cache)

    assert input_embeds is not None

    rope_pos = input_pos + 1  # PaliGemma2 positions are 1-based.
    # The RoPE parameters of all attn_configs are the same; take the first one.
    attn_config = self.config.block_config(0).attn_config
    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
    rope = rotary_pos_emb.build_rope(
        rope_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
    )

    if called_by_generate:
      # PaliGemma2 generate() uses a diagonal causal mask even with image
      # embeddings.
      mask = [
          self.get_attention_mask(
              self.config.block_config(i).attn_config.attn_type, input_pos
          )
          for i in range(self.config.num_layers)
      ]
    else:
      # By default, don't mask image embeddings with a diagonal causal mask.
      embeds_len = input_embeds.shape[1]
      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
      mask[:, embeds_len:] = float("-inf")
      mask = [mask] * self.config.num_layers

    return self._forward_with_embeds(
        input_embeds, rope, mask, input_pos, kv_cache, export_config
    )


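# Note (assumed call pattern, not part of the original module): in the full
# PaliGemma2 pipeline, the caller builds `input_embeds` outside this class by
# projecting image features into the language model's embedding space and
# concatenating them with the text token embeddings; Decoder2.forward() only
# consumes the already-merged tensor. A rough sketch with hypothetical names:
#
#   image_embeds = projector(image_encoder(pixel_values))
#   text_embeds = embedder(tokens)
#   input_embeds = torch.cat((image_embeds, text_embeds), dim=1)
#   out = decoder(tokens, input_pos, kv_cache, input_embeds=input_embeds)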
def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for the decoder of a PaliGemma2 3B model.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache.
      Default is 1024.

  Returns:
    The model config for the decoder of a PaliGemma2 3B model.
  """
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-6,
      zero_centered=True,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
      intermediate_size=9216,
      pre_ff_norm_config=norm_config,
      post_ff_norm_config=norm_config,
  )

  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
    attn_config = cfg.AttentionConfig(
        num_heads=8,
        head_dim=256,
        num_query_groups=4,
        rotary_base=10000,
        rotary_percentage=1.0,
        logit_softcap=50.0,
        sliding_window_size=4096,
        attn_type=(
            cfg.AttentionType.GLOBAL
            if idx % 2 == 0
            else cfg.AttentionType.LOCAL_SLIDING
        ),
    )
    return cfg.TransformerBlockConfig(
        attn_config=attn_config,
        ff_config=ff_config,
        pre_attention_norm_config=norm_config,
        post_attention_norm_config=norm_config,
    )

  num_layers = 26
  embedding_dim = 2304
  config = cfg.ModelConfig(
      vocab_size=257216,
      num_layers=num_layers,
      max_seq_len=8192,
      embedding_dim=embedding_dim,
      embedding_scale=embedding_dim**0.5,
      kv_cache_max_len=kv_cache_max_len,
      block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
      lm_head_use_bias=False,
      enable_hlfb=True,
      final_logit_softcap=30.0,
  )
  return config


def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  """Returns a scaled-down decoder config (e.g. for tests)."""
  config = get_decoder2_config(kv_cache_max_len)
  # All block configs share the same ff_config object, so updating it through
  # block_config(0) is enough to shrink every layer.
  config.block_config(0).ff_config.intermediate_size = 128
  config.vocab_size = 128
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  return config


def build_decoder2(checkpoint_path: str, **kwargs) -> torch.nn.Module:
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=get_decoder2_config(**kwargs),
      tensor_names=TENSOR_NAMES,
      model_class=Decoder2,
  )
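

# A minimal usage sketch (an assumption, not part of the original module): the
# checkpoint path is hypothetical, and the call pattern mirrors the repo's
# verification utilities, assuming kv_utils.KVCache.from_model_config is
# available.
#
#   decoder = build_decoder2("/path/to/paligemma2-3b/language_model")
#   kv = kv_utils.KVCache.from_model_config(decoder.config)
#   tokens = torch.zeros((1, 8), dtype=torch.int)
#   input_pos = torch.arange(0, 8, dtype=torch.int)
#   output = decoder(tokens, input_pos, kv)  # {"logits": ..., "kv_cache": ...}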