|  | 
|  | 1 | +# Copyright 2024 The AI Edge Torch Authors. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | +# ============================================================================== | 
|  | 15 | + | 
|  | 16 | +"""Example of building a Phi-4 model up to 4K tokens, not to 128K tokens.""" | 
|  | 17 | + | 
|  | 18 | +from functools import partial | 
|  | 19 | +import math | 
|  | 20 | +from typing import Tuple | 
|  | 21 | + | 
|  | 22 | +import ai_edge_torch.generative.layers.model_config as cfg | 
|  | 23 | +from ai_edge_torch.generative.utilities import model_builder | 
|  | 24 | +import ai_edge_torch.generative.utilities.loader as loading_utils | 
|  | 25 | +import torch | 
|  | 26 | + | 
|  | 27 | +TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( | 
|  | 28 | +    ff_up_proj="model.layers.{}.mlp.gate_up_proj", | 
|  | 29 | +    ff_down_proj="model.layers.{}.mlp.down_proj", | 
|  | 30 | +    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj", | 
|  | 31 | +    attn_output_proj="model.layers.{}.self_attn.o_proj", | 
|  | 32 | +    pre_attn_norm="model.layers.{}.input_layernorm", | 
|  | 33 | +    post_attn_norm="model.layers.{}.post_attention_layernorm", | 
|  | 34 | +    embedding="model.embed_tokens", | 
|  | 35 | +    final_norm="model.norm", | 
|  | 36 | +) | 
|  | 37 | + | 
|  | 38 | +# max_position_embeddings / original_max_position_embeddings in Phi-4 config. | 
|  | 39 | +ROPE_SCALE_FACTOR = 32 | 
|  | 40 | + | 
|  | 41 | +# ROPE short factor in Phi-4 config. According to LOPE paper and its code in | 
|  | 42 | +# https://github.com/microsoft/LongRoPE, these values had been searched with | 
|  | 43 | +# min=1.0, step-0.01 to optimize the errors of sample dataset. | 
|  | 44 | +ROPE_SHORT_FACTOR = [1.0] * 48 | 
|  | 45 | + | 
|  | 46 | + | 
|  | 47 | +def _build_phi4_rope( | 
|  | 48 | +    input_pos: int, | 
|  | 49 | +    n_elem: int, | 
|  | 50 | +    base: int, | 
|  | 51 | +    condense_ratio: int, | 
|  | 52 | +    dtype: torch.dtype, | 
|  | 53 | +    device: torch.device, | 
|  | 54 | +    theta_factors: torch.Tensor, | 
|  | 55 | +    scale: float, | 
|  | 56 | +) -> Tuple[torch.Tensor, torch.Tensor]: | 
|  | 57 | +  """Computes Rotary Positional Embeddings for Phi-4 model. | 
|  | 58 | +
 | 
|  | 59 | +  It's a modified version of attn_utils.build_rope_cache with additional | 
|  | 60 | +  arguments for Phi-4 model. It precompute Rotary Positional Embedding Sin and | 
|  | 61 | +  Cos values with scaling factors for quick lookup during the inference. | 
|  | 62 | +
 | 
|  | 63 | +  Args: | 
|  | 64 | +      input_pos (torch.Tensor): the given input sequence positions | 
|  | 65 | +      n_elem (int): Each sequence's dimmension. | 
|  | 66 | +      base (int, optional): Rope base value. | 
|  | 67 | +      condense_ratio (int, optional): The ratio by which sequence indicies are | 
|  | 68 | +        condensed. | 
|  | 69 | +      dtype (torch.dtype, optional): Output tensor's data type. | 
|  | 70 | +      device (torch.device, optional): Output tensor's data type. | 
|  | 71 | +      theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used | 
|  | 72 | +        to scale the theta values. | 
|  | 73 | +      scale (float, optional): A float used to scale the rope values. | 
|  | 74 | +
 | 
|  | 75 | +  Returns: | 
|  | 76 | +      Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves. | 
|  | 77 | +  """ | 
|  | 78 | +  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem)) | 
|  | 79 | +  theta = theta / theta_factors | 
|  | 80 | +  seq_idx = input_pos / condense_ratio | 
|  | 81 | +  idx_theta = torch.outer(seq_idx, theta) | 
|  | 82 | +  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale | 
|  | 83 | +  sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale | 
|  | 84 | +  return cos, sin | 
|  | 85 | + | 
|  | 86 | + | 
|  | 87 | +class Phi4Mini(model_builder.DecoderOnlyModel): | 
|  | 88 | +  """A Phi-4 model built from the Edge Generative API layers.""" | 
|  | 89 | +  pass | 
|  | 90 | + | 
|  | 91 | + | 
|  | 92 | +def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: | 
|  | 93 | +  """Returns the model config for a Phi-4 model. | 
|  | 94 | +
 | 
|  | 95 | +  Args: | 
|  | 96 | +    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default | 
|  | 97 | +      is 1024. | 
|  | 98 | +
 | 
|  | 99 | +  Returns: | 
|  | 100 | +    The model config for a Phi-4 model. | 
|  | 101 | +  """ | 
|  | 102 | +  attn_config = cfg.AttentionConfig( | 
|  | 103 | +      num_heads=24, | 
|  | 104 | +      head_dim=128, | 
|  | 105 | +      num_query_groups=8, | 
|  | 106 | +      rotary_base=10000, | 
|  | 107 | +      rotary_percentage=0.75, | 
|  | 108 | +      qkv_transpose_before_split=True, | 
|  | 109 | +  ) | 
|  | 110 | +  ff_config = cfg.FeedForwardConfig( | 
|  | 111 | +      type=cfg.FeedForwardType.SEQUENTIAL, | 
|  | 112 | +      activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU), | 
|  | 113 | +      intermediate_size=8192, | 
|  | 114 | +  ) | 
|  | 115 | +  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM) | 
|  | 116 | +  block_config = cfg.TransformerBlockConfig( | 
|  | 117 | +      attn_config=attn_config, | 
|  | 118 | +      ff_config=ff_config, | 
|  | 119 | +      pre_attention_norm_config=norm_config, | 
|  | 120 | +      post_attention_norm_config=norm_config, | 
|  | 121 | +  ) | 
|  | 122 | + | 
|  | 123 | +  max_seq_len = 4096 | 
|  | 124 | +  # Create the RoPE callable | 
|  | 125 | +  build_rope = partial( | 
|  | 126 | +      _build_phi4_rope, | 
|  | 127 | +      condense_ratio=1, | 
|  | 128 | +      dtype=torch.float32, | 
|  | 129 | +      device=torch.device("cpu"), | 
|  | 130 | +      theta_factors=torch.tensor(ROPE_SHORT_FACTOR), | 
|  | 131 | +      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)), | 
|  | 132 | +  ) | 
|  | 133 | + | 
|  | 134 | +  config = cfg.ModelConfig( | 
|  | 135 | +      vocab_size=200064, | 
|  | 136 | +      num_layers=32, | 
|  | 137 | +      max_seq_len=max_seq_len, | 
|  | 138 | +      kv_cache_max_len=kv_cache_max_len, | 
|  | 139 | +      embedding_dim=3072, | 
|  | 140 | +      block_configs=block_config, | 
|  | 141 | +      final_norm_config=norm_config, | 
|  | 142 | +      enable_hlfb=True, | 
|  | 143 | +      build_rope=build_rope, | 
|  | 144 | +  ) | 
|  | 145 | +  return config | 
|  | 146 | + | 
|  | 147 | + | 
|  | 148 | +def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: | 
|  | 149 | +  config = get_model_config(kv_cache_max_len) | 
|  | 150 | +  config.vocab_size = 128 | 
|  | 151 | +  config.num_layers = 2 | 
|  | 152 | +  config.max_seq_len = 2 * kv_cache_max_len | 
|  | 153 | +  # Phi-4 has only one block config. | 
|  | 154 | +  config.block_config(0).ff_config.intermediate_size = 128 | 
|  | 155 | +  return config | 
|  | 156 | + | 
|  | 157 | + | 
|  | 158 | +def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module: | 
|  | 159 | +  """Instantiates the model instance and load checkpoint if provided.""" | 
|  | 160 | +  return model_builder.build_decoder_only_model( | 
|  | 161 | +      checkpoint_path=checkpoint_path, | 
|  | 162 | +      config=get_model_config(**kwargs), | 
|  | 163 | +      tensor_names=TENSOR_NAMES, | 
|  | 164 | +      model_class=Phi4Mini, | 
|  | 165 | +  ) | 
0 commit comments