
Commit 42eb2ef

ai-edge-bot authored and copybara-github committed
Add DeepSeek-R1-Distill-Qwen as an example
- It's based on Qwen, with some slight changes including separate lm_head weights.

PiperOrigin-RevId: 719455977
1 parent 25db678 commit 42eb2ef
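
The "separate lm_head weights" detail is the main divergence from the existing Qwen example: DeepSeek-R1-Distill-Qwen keeps a standalone output-projection tensor instead of tying it to the token embedding. A minimal sketch of how the new deepseek.py below encodes this (both names are taken from this commit; everything else is elided):

```python
# Sketch: how the "separate lm_head weights" change surfaces in deepseek.py.
from ai_edge_torch.generative.utilities import model_builder

# Load the checkpoint's standalone lm_head weights rather than reusing the
# token-embedding weights:
TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD

# ...paired with the corresponding cfg.ModelConfig field in get_model_config():
#   lm_head_share_weight_with_embedding=False
```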

File tree

6 files changed: +271 -2 lines changed

ai_edge_torch/generative/examples/README.md
ai_edge_torch/generative/examples/deepseek/__init__.py
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
ai_edge_torch/generative/examples/deepseek/deepseek.py
ai_edge_torch/generative/examples/deepseek/verify.py
ai_edge_torch/generative/test/test_model_conversion_large.py

ai_edge_torch/generative/examples/README.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -52,6 +52,10 @@ same architecture as SmolLM but it has been trained on improved training data.
 Alibaba's [Qwen 2.5](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e)
 0.5B, 1B, 3B models are also provided as examples.
 
+## DeepSeek
+[DeepSeek-R1 distilled](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)
+model based on Qwen is also provided as an example.
+
 ## AMD-Llama-135m
 
 [AMD-Llama-135m](https://huggingface.co/amd/AMD-Llama-135m) is a 135M parameter model based on the Llama2 architecture and uses the same tokenizer as Llama2. It was trained on AMD Instinct MI250 accelerators.
```
ai_edge_torch/generative/examples/deepseek/__init__.py

Lines changed: 14 additions & 0 deletions

```python
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
```
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py

Lines changed: 80 additions & 0 deletions

```python
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of converting DeepSeek R1 distilled models to tflite model."""

import os
import pathlib

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.deepseek import deepseek
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

_CHECKPOINT_PATH = flags.DEFINE_string(
    'checkpoint_path',
    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/deepseek'),
    'The path to the model checkpoint, or directory holding the checkpoint.',
)
_OUTPUT_PATH = flags.DEFINE_string(
    'output_path',
    '/tmp/',
    'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
    'output_name_prefix',
    'deepseek',
    'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    'prefill_seq_lens',
    (8, 64, 128, 256, 512, 1024),
    'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
    'kv_cache_max_len',
    1280,
    'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
    'quantize',
    True,
    'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
    'lora_ranks',
    None,
    'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
  pytorch_model = deepseek.build_model(
      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  converter.convert_to_tflite(
      pytorch_model,
      output_path=_OUTPUT_PATH.value,
      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
      prefill_seq_len=_PREFILL_SEQ_LENS.value,
      quantize=_QUANTIZE.value,
      lora_ranks=_LORA_RANKS.value,
      export_config=ExportConfig(),
  )


if __name__ == '__main__':
  app.run(main)
```
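
With the defaults above, running the script with no flags looks for a checkpoint under `~/Downloads/llm_data/deepseek` and writes quantized TFLite output to `/tmp/`, with one prefill signature per entry in `prefill_seq_lens`. A hypothetical invocation overriding the checkpoint location would be `python convert_to_tflite.py --checkpoint_path=/path/to/DeepSeek-R1-Distill-Qwen-1.5B` (the path here is illustrative).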
ai_edge_torch/generative/examples/deepseek/deepseek.py

Lines changed: 92 additions & 0 deletions

```python
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of building DeepSeek R1 distilled models."""

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
from torch import nn

TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD


class DeepSeekDistillQwen(model_builder.DecoderOnlyModel):
  """A DeepSeek distilled model based on Qwen."""
  pass


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a DeepSeek-R1-Distill-Qwen 1.5B model.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache.
      Default is 1024.

  Returns:
    The model config for a DeepSeek-R1-Distill-Qwen model.
  """
  attn_config = cfg.AttentionConfig(
      num_heads=12,
      head_dim=128,
      num_query_groups=2,
      rotary_base=10000,
      rotary_percentage=1.0,
      qkv_use_bias=True,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
      intermediate_size=8960,
  )
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-06,
  )
  block_config = cfg.TransformerBlockConfig(
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
  )
  config = cfg.ModelConfig(
      vocab_size=151936,
      num_layers=28,
      max_seq_len=4096,
      embedding_dim=1536,
      kv_cache_max_len=kv_cache_max_len,
      block_configs=block_config,
      final_norm_config=norm_config,
      lm_head_share_weight_with_embedding=False,
      enable_hlfb=True,
  )
  return config


def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
  config = get_model_config(**kwargs)
  config.vocab_size = 128
  config.num_layers = 2
  # DeepSeek-R1-Distill-Qwen has only one block config.
  config.block_config(0).ff_config.intermediate_size = 64
  return config


def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=get_model_config(**kwargs),
      tensor_names=TENSOR_NAMES,
      model_class=DeepSeekDistillQwen,
  )
```
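
Since the real checkpoint is a 1.5B-parameter download, the fake config above is handy for a quick structural smoke test. A minimal sketch, mirroring what the conversion test added below does (no checkpoint needed; the printed values follow from get_fake_model_config):

```python
# Build the tiny fake-config variant of the model, as the conversion test does.
from ai_edge_torch.generative.examples.deepseek import deepseek

config = deepseek.get_fake_model_config()
model = deepseek.DeepSeekDistillQwen(config).eval()
print(config.vocab_size, config.num_layers)  # 128 2
```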
ai_edge_torch/generative/examples/deepseek/verify.py

Lines changed: 70 additions & 0 deletions

```python
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Verifies the reauthored DeepSeek R1 distilled 1.5B model."""

import logging
import pathlib

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.deepseek import deepseek
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier
import transformers


_PROMPTS = flags.DEFINE_multi_string(
    "prompts",
    "What is the meaning of life?",
    "The input prompts to generate answers.",
)
_MAX_NEW_TOKENS = flags.DEFINE_integer(
    "max_new_tokens",
    30,
    "The maximum size of the generated tokens.",
)


def main(_):
  checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
      checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = deepseek.build_model(reauthored_checkpoint)

  logging.info("Loading the tokenizer from: %s", checkpoint)
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

  verifier.verify_reauthored_model(
      original_model=transformers_verifier.TransformersModelWrapper(
          original_model
      ),
      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
      tokenizer=verifier.TokenizerWrapper(tokenizer),
      generate_prompts=_PROMPTS.value,
      max_new_tokens=_MAX_NEW_TOKENS.value,
      atol=1e-04,
  )


if __name__ == "__main__":
  app.run(main)
```
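
Note that this script downloads deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B from Hugging Face, builds the reauthored model from the same cached checkpoint directory, and compares the two models' outputs on each prompt within atol=1e-04; a typical (hypothetical) invocation would be `python verify.py --max_new_tokens=30`.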

ai_edge_torch/generative/test/test_model_conversion_large.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -17,6 +17,7 @@
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+from ai_edge_torch.generative.examples.deepseek import deepseek
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
@@ -150,16 +151,15 @@ def test_smollm(self):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-
   def test_smollm2(self):
     config = smollm.get_fake_model_config_v2()
     pytorch_model = smollm.SmolLM2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
@@ -174,6 +174,15 @@ def test_qwen(self):
     pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_deepseek(self):
+    config = deepseek.get_fake_model_config()
+    pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
```