
Commit 4bf9d76

ai-edge-bot authored and copybara-github committed
Smollm2 implementation for ai_torch_edge.
PiperOrigin-RevId: 713507777
1 parent 85446ef commit 4bf9d76

5 files changed: +140 −3 lines


ai_edge_torch/generative/examples/README.md

Lines changed: 3 additions & 1 deletion
@@ -40,11 +40,13 @@ with 270M, 450M, 1.1B, and 3B parameters. The example we provide is OpenELM 3B,
 and the checkpoint for the model can be found
 [here](https://huggingface.co/apple/OpenELM-3B/tree/main).

-## HuggingFace SmolLM
+## HuggingFace SmolLM and SmolLM2
 [HuggingFace SmolLM](https://huggingface.co/blog/smollm) is also a decoder-only
 LLM with 135M, 360M, 1.7B parameters. The example we provide is SmolLM 135M, and
 the checkpoint for the model can be found
 [here](https://huggingface.co/HuggingFaceTB/SmolLM-135M).
+Similarly, [SmolLM2](https://huggingface.co/HuggingFaceTB/SmolLM2-135M) has the
+same architecture as SmolLM but was trained on improved data.

 ## Qwen
 Alibaba's [Qwen 2.5](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e)
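
Since SmolLM2 shares SmolLM's architecture, both Hugging Face checkpoints should resolve to the same model class. A minimal sketch (checkpoint names are from the README above; that both configs report `LlamaForCausalLM` is an assumption about the upstream checkpoints, not something this commit states):

```python
# Sketch: check that SmolLM and SmolLM2 declare the same architecture.
# Checkpoint names come from the README; the expected class name is an
# assumption about the Hugging Face configs.
import transformers

for name in ("HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM2-135M"):
    config = transformers.AutoConfig.from_pretrained(name)
    print(name, config.architectures)  # expected: ['LlamaForCausalLM'] for both
```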
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmolLM2 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = smollm.build_model_v2(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
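
Stripped of the absl flag plumbing, the new script reduces to two calls that both appear in this commit; a minimal sketch with placeholder paths:

```python
# Programmatic equivalent of the conversion script above; paths are placeholders.
import os

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

model = smollm.build_model_v2('/path/to/smollm2', kv_cache_max_len=1280)
converter.convert_to_tflite(
    model,
    tflite_path=os.path.join('/tmp', 'smollm2_q8_ekv1280.tflite'),
    prefill_seq_len=[8, 64, 128, 256, 512, 1024],  # one entry per prefill size
    quantize=True,  # q8 output name, matching the script's default
    export_config=ExportConfig(),
)
```

Per the flag descriptions, each `prefill_seq_len` entry is the maximum size of one prefill input tensor, which is what makes the exported `.tflite` multi-signature.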

ai_edge_torch/generative/examples/smollm/smollm.py

Lines changed: 38 additions & 0 deletions
@@ -85,3 +85,41 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
   )
+
+
+class SmolLM2(model_builder.DecoderOnlyModel):
+  """A SmolLM2 model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM2 model.
+  """
+  config = get_model_config(kv_cache_max_len)
+  config.block_config(0).attn_config.rotary_base = 100000
+  return config
+
+
+def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_v2(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_v2(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=SmolLM2,
+  )
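
Per `get_model_config_v2` above, the v2 config is the v1 config with a single override, the RoPE base frequency. A quick sketch to see the delta, using only functions from this file:

```python
# Compare the v1 and v2 configs; get_model_config_v2 only overrides rotary_base.
from ai_edge_torch.generative.examples.smollm import smollm

v1 = smollm.get_model_config(kv_cache_max_len=1024)
v2 = smollm.get_model_config_v2(kv_cache_max_len=1024)
print(v1.block_config(0).attn_config.rotary_base)  # SmolLM default
print(v2.block_config(0).attn_config.rotary_base)  # 100000, per the diff
```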

ai_edge_torch/generative/examples/smollm/verify.py

Lines changed: 18 additions & 2 deletions
@@ -36,10 +36,26 @@
     30,
     "The maximum size of the generated tokens.",
 )
+_MODEL_VERSION = flags.DEFINE_enum(
+    "model_version",
+    "v1",
+    ["v1", "v2"],
+    "The version of SmolLM to verify.",
+)
+_CHECKPOINT = {
+    "v1": "HuggingFaceTB/SmolLM-135M",
+    "v2": "HuggingFaceTB/SmolLM2-135M",
+}
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}


 def main(_):
-  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
+  builder = _BUILDER[_MODEL_VERSION.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

@@ -49,7 +65,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = smollm.build_model(reauthored_checkpoint)
+  reauthored_model = builder(reauthored_checkpoint)

   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
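
The version switch is two dictionary lookups keyed by the new enum flag, so verification runs as, for example, `python verify.py --model_version=v2`. A condensed sketch of the dispatch (the local checkpoint path is a placeholder):

```python
# Condensed form of the new dispatch in verify.py.
from ai_edge_torch.generative.examples.smollm import smollm

_CHECKPOINT = {
    "v1": "HuggingFaceTB/SmolLM-135M",
    "v2": "HuggingFaceTB/SmolLM2-135M",
}
_BUILDER = {"v1": smollm.build_model, "v2": smollm.build_model_v2}

version = "v2"
checkpoint = _CHECKPOINT[version]
reauthored_model = _BUILDER[version]("/path/to/cached/checkpoint")  # placeholder
```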

ai_edge_torch/generative/test/test_model_conversion_large.py

Lines changed: 10 additions & 0 deletions
@@ -150,6 +150,16 @@ def test_smollm(self):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
+
+  def test_smollm2(self):
+    config = smollm.get_fake_model_config_v2()
+    pytorch_model = smollm.SmolLM2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
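
The test exercises conversion on a shrunken model: `get_fake_model_config_v2` keeps the SmolLM2 wiring (including the `rotary_base` override) but cuts vocabulary size, depth, and FFN width so the conversion path runs quickly. A sketch of what the test builds:

```python
# Build the tiny SmolLM2 used by the conversion test (values from the diff).
from ai_edge_torch.generative.examples.smollm import smollm

config = smollm.get_fake_model_config_v2()
model = smollm.SmolLM2(config).eval()
print(config.vocab_size, config.num_layers)  # 128, 2
```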
