
Commit 29311c6

ai-edge-bot authored and copybara-github committed
Support Phi-4 model
PiperOrigin-RevId: 732233721
1 parent 883075e commit 29311c6

File tree

6 files changed: +334, -12 lines changed

ai_edge_torch/generative/examples/README.md

Lines changed: 8 additions & 7 deletions
@@ -26,13 +26,14 @@ found [here](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/tree/main).

 ## TinyLlama
 [TinyLlama](https://github.com/jzhang38/TinyLlama) is a popular OSS smaller version of Meta's Llama2 model, with only 1.1B parameters. [HuggingFace checkpoint](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).

-## Microsoft Phi-2 and 3.5-mini
-Microsoft Phi-2 and Phi-3.5-mini are also decoder-only LLMs with 2.7B and 3.82B
-parameters each. See details on
-[Kaggle](https://www.kaggle.com/models/Microsoft/phi/transformers/2) for Phi-2
-and [HuggingFace](https://huggingface.co/microsoft/Phi-3.5-mini-instruct) for
-Phi-3.5-mini. Note that the example of Phi-3.5-mini supports up to 4K tokens,
-not to 128K tokens which the original Phi-3.5 supports.
+## Microsoft Phi-2, 3.5-mini, and 4-mini
+Microsoft Phi-2, Phi-3.5-mini, and Phi-4-mini are also decoder-only LLMs with
+2.7B, 3.82B, and 3.84B parameters respectively. See details on
+[Kaggle](https://www.kaggle.com/models/Microsoft/phi/transformers/2) for Phi-2,
+[HuggingFace](https://huggingface.co/microsoft/Phi-3.5-mini-instruct) for Phi-3.5-mini,
+and [HuggingFace](https://huggingface.co/microsoft/Phi-4-mini-instruct) for Phi-4-mini.
+Note that the examples of Phi-3.5-mini and Phi-4-mini support up to 4K tokens,
+not the 128K tokens that the original models support.

 ## Apple OpenELM
[Apple OpenELM](https://huggingface.co/apple/OpenELM) is also a decoder-only LLM
Lines changed: 80 additions & 0 deletions (new file: Phi-4 conversion example)

# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of converting a Phi-4 model to a multi-signature tflite model."""

import os
import pathlib

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.phi import phi4
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
    'checkpoint_path',
    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi4'),
    'The path to the model checkpoint, or directory holding the checkpoint.',
)
_OUTPUT_PATH = flags.DEFINE_string(
    'output_path',
    '/tmp/',
    'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
    'output_name_prefix',
    'phi4',
    'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    'prefill_seq_lens',
    (8, 64, 128, 256, 512, 1024),
    'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
    'kv_cache_max_len',
    1280,
    'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
    'quantize',
    True,
    'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
    'lora_ranks',
    None,
    'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
  pytorch_model = phi4.build_model(
      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  converter.convert_to_tflite(
      pytorch_model,
      output_path=_OUTPUT_PATH.value,
      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
      prefill_seq_len=_PREFILL_SEQ_LENS.value,
      quantize=_QUANTIZE.value,
      lora_ranks=_LORA_RANKS.value,
      export_config=ExportConfig(),
  )


if __name__ == '__main__':
  app.run(main)
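Each value in --prefill_seq_lens becomes its own prefill signature in the exported multi-signature tflite model, alongside a decode signature for token-by-token generation. A minimal sketch of inspecting the result, assuming TensorFlow is installed; the filename is an assumption derived from the flag defaults above, not a name this commit guarantees:

import tensorflow as tf

# Hypothetical output name; use whatever the converter actually wrote
# under --output_path.
interpreter = tf.lite.Interpreter(model_path="/tmp/phi4_q8_ekv1280.tflite")
# Lists the exported signatures, e.g. one prefill entry per sequence
# length plus a decode entry.
print(interpreter.get_signature_list())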

ai_edge_torch/generative/examples/phi/phi3.py

Lines changed: 2 additions & 5 deletions
@@ -136,10 +136,7 @@ def _build_phi3_rope(

 class Phi3_5Mini(model_builder.DecoderOnlyModel):
   """A Phi-3.5 model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    attn_config = self.config.block_config(0).attn_config
+  pass


 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -150,7 +147,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     is 1024.

   Returns:
-    The model config for a Phi-2 model.
+    The model config for a Phi-3.5 model.
   """
   attn_config = cfg.AttentionConfig(
       num_heads=32,
ai_edge_torch/generative/examples/phi/phi4.py

Lines changed: 165 additions & 0 deletions (new file)

# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of building a Phi-4 model, supporting up to 4K tokens instead of the original 128K."""

from functools import partial
import math
from typing import Tuple

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="model.layers.{}.mlp.gate_up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    post_attn_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
)

# max_position_embeddings / original_max_position_embeddings in Phi-4 config.
ROPE_SCALE_FACTOR = 32

# RoPE short factor in Phi-4 config. According to the LongRoPE paper and its
# code in https://github.com/microsoft/LongRoPE, these values were searched
# with min=1.0, step=0.01 to minimize the error on a sample dataset.
ROPE_SHORT_FACTOR = [1.0] * 48


def _build_phi4_rope(
    input_pos: torch.Tensor,
    n_elem: int,
    base: int,
    condense_ratio: int,
    dtype: torch.dtype,
    device: torch.device,
    theta_factors: torch.Tensor,
    scale: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Computes Rotary Positional Embeddings for the Phi-4 model.

  It is a modified version of attn_utils.build_rope_cache with additional
  arguments for the Phi-4 model. It precomputes rotary positional embedding
  sin and cos values with scaling factors for quick lookup during inference.

  Args:
    input_pos (torch.Tensor): the given input sequence positions.
    n_elem (int): Each sequence's dimension.
    base (int, optional): RoPE base value.
    condense_ratio (int, optional): The ratio by which sequence indices are
      condensed.
    dtype (torch.dtype, optional): Output tensor's data type.
    device (torch.device, optional): Output tensor's device.
    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
      to scale the theta values.
    scale (float, optional): A float used to scale the RoPE values.

  Returns:
    Tuple[torch.Tensor, torch.Tensor]: RoPE's cosine and sine waves.
  """
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  theta = theta / theta_factors
  seq_idx = input_pos / condense_ratio
  idx_theta = torch.outer(seq_idx, theta)
  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
  sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
  return cos, sin


class Phi4Mini(model_builder.DecoderOnlyModel):
  """A Phi-4 model built from the Edge Generative API layers."""
  pass


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Phi-4 model.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache.
      Default is 1024.

  Returns:
    The model config for a Phi-4 model.
  """
  attn_config = cfg.AttentionConfig(
      num_heads=24,
      head_dim=128,
      num_query_groups=8,
      rotary_base=10000,
      rotary_percentage=0.75,
      qkv_transpose_before_split=True,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.SEQUENTIAL,
      activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
      intermediate_size=8192,
  )
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
  block_config = cfg.TransformerBlockConfig(
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
  )

  max_seq_len = 4096
  # Create the RoPE callable.
  build_rope = partial(
      _build_phi4_rope,
      condense_ratio=1,
      dtype=torch.float32,
      device=torch.device("cpu"),
      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
  )

  config = cfg.ModelConfig(
      vocab_size=200064,
      num_layers=32,
      max_seq_len=max_seq_len,
      kv_cache_max_len=kv_cache_max_len,
      embedding_dim=3072,
      block_configs=block_config,
      final_norm_config=norm_config,
      enable_hlfb=True,
      build_rope=build_rope,
  )
  return config


def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  config = get_model_config(kv_cache_max_len)
  config.vocab_size = 128
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  # Phi-4 has only one block config.
  config.block_config(0).ff_config.intermediate_size = 128
  return config


def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
  """Instantiates the model instance and loads checkpoint if provided."""
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=get_model_config(**kwargs),
      tensor_names=TENSOR_NAMES,
      model_class=Phi4Mini,
  )
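The scale passed to _build_phi4_rope above is the LongRoPE-style attention factor sqrt(1 + ln(ROPE_SCALE_FACTOR) / ln(max_seq_len)). With ROPE_SCALE_FACTOR = 32 = 2^5 and max_seq_len = 4096 = 2^12, the log ratio is exactly 5/12, giving a constant of roughly 1.1902 baked into the cos/sin tables. A quick check using only values defined in this file:

import math

# sqrt(1 + ln(32)/ln(4096)) = sqrt(1 + 5/12) = sqrt(17/12)
scale = math.sqrt(1 + math.log(32) / math.log(4096))
print(f"{scale:.4f}")  # 1.1902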
Lines changed: 69 additions & 0 deletions (new file: Phi-4 verification example)

# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Verifies the reauthored Phi-4 model."""

import logging
import pathlib

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.phi import phi4
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier
import transformers


_PROMPTS = flags.DEFINE_multi_string(
    "prompts",
    "Instruct: Write an email about the weather Output:",
    "The input prompts to generate answers.",
)
_MAX_NEW_TOKENS = flags.DEFINE_integer(
    "max_new_tokens",
    30,
    "The maximum size of the generated tokens.",
)


def main(_):
  checkpoint = "microsoft/Phi-4-mini-instruct"
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
      checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = phi4.build_model(reauthored_checkpoint)

  logging.info("Loading the tokenizer from: %s", checkpoint)
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

  verifier.verify_reauthored_model(
      original_model=transformers_verifier.TransformersModelWrapper(
          original_model
      ),
      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
      tokenizer=verifier.TokenizerWrapper(tokenizer),
      generate_prompts=_PROMPTS.value,
      max_new_tokens=_MAX_NEW_TOKENS.value,
  )


if __name__ == "__main__":
  app.run(main)
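verify_reauthored_model drives the original HuggingFace model and the reauthored one with the same prompts and compares their outputs; the wrapper classes and tolerances live in the verifier utilities, outside this commit. A rough sketch of the comparison idea, with the reauthored side elided:

import torch
import transformers

checkpoint = "microsoft/Phi-4-mini-instruct"
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
original = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

tokens = tokenizer("Instruct: Write an email", return_tensors="pt")
with torch.no_grad():
  original_logits = original(**tokens).logits[0, -1]
# The reauthored model built by phi4.build_model(...) would be run on the
# same token ids and its final-position logits compared, e.g. with
# torch.allclose under a tolerance.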

ai_edge_torch/generative/test/test_model_conversion_large.py

Lines changed: 10 additions & 0 deletions
@@ -27,6 +27,7 @@
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.examples.phi import phi4
 from ai_edge_torch.generative.examples.qwen import qwen
 from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
 from ai_edge_torch.generative.examples.smollm import smollm
@@ -139,6 +140,15 @@ def test_phi3(self):
     pytorch_model = phi3.Phi3_5Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_phi4(self):
+    config = phi4.get_fake_model_config()
+    pytorch_model = phi4.Phi4Mini(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
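get_fake_model_config shrinks Phi-4 to 2 layers, a 128-token vocabulary, and a 128-wide feed-forward so the conversion test stays fast. A small sketch of the model the test exercises, using only names from this diff:

from ai_edge_torch.generative.examples.phi import phi4

# Tiny test-only variant; the printed parameter count is illustrative,
# not something the test asserts.
config = phi4.get_fake_model_config()
model = phi4.Phi4Mini(config).eval()
print(sum(p.numel() for p in model.parameters()))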
