
Commit 5c15fe2

Authored by xenova, kunal-vaishnavi, and LorenRd
Add support for smollm3 (microsoft#1666)
cc @guschmue

Co-authored-by: kunal-vaishnavi <[email protected]>
Co-authored-by: Lorenzo Rondán <[email protected]>
1 parent 18431d8 commit 5c15fe2

4 files changed: +34 -2 lines changed


README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -16,7 +16,7 @@ See documentation at https://onnxruntime.ai/docs/genai.
 
 |Support matrix|Supported now|Under development|On the roadmap|
 | -------------- | ------------- | ----------------- | -------------- |
-| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Gemma <br/> Granite <br/> Llama * <br/> Mistral + <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen | Whisper | Stable diffusion |
+| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Gemma <br/> Granite <br/> Llama * <br/> Mistral + <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen <br/> SmolLM3 | Whisper | Stable diffusion |
 |API| Python <br/>C# <br/>C/C++ <br/> Java ^ |Objective-C||
 |Platform| Linux <br/> Windows <br/>Mac ^ <br/>Android ^ ||iOS |
 |Architecture|x86 <br/> x64 <br/> Arm64 ~ |||
```
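For context, the Python entry in the API row above is enough to run an exported SmolLM3 model end to end. A minimal generation loop as a sketch, assuming a recent onnxruntime-genai release and a model exported to `./smollm3-onnx` (the path and prompt are illustrative, and a real run would apply the model's chat template first):

```python
# Minimal sketch: token-by-token generation with the onnxruntime-genai Python API.
# The model folder "./smollm3-onnx" is hypothetical.
import onnxruntime_genai as og

model = og.Model("./smollm3-onnx")
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.set_search_options(max_length=128)

generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode("What is ONNX Runtime?"))

# Decode one token at a time until the search finishes.
while not generator.is_done():
    generator.generate_next_token()

print(tokenizer.decode(generator.get_sequence(0)))
```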

src/models/model_type.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -12,7 +12,7 @@ namespace Generators {
 struct ModelType {
   inline static bool IsLLM(const std::string& model_type) {
     // Large-language model (LLM)
-    static constexpr std::array<std::string_view, 18> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3"};
+    static constexpr std::array<std::string_view, 19> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
     return std::find(LLM.begin(), LLM.end(), model_type) != LLM.end();
   }
 
```

src/python/py/models/README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -42,6 +42,7 @@ The tool currently supports the following model architectures.
 - Nemotron
 - Phi
 - Qwen
+- SmolLM3
 
 It is intended for supporting the latest, popular state-of-the-art models.
 
```
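With SmolLM3 on this list, the builder described in this README can export it directly. A sketch of the invocation, assuming the Hugging Face id `HuggingFaceTB/SmolLM3-3B` and an int4 CPU target (the model id, paths, and options are illustrative):

```bash
# Hypothetical builder run; the flags follow the usage documented in this README.
python3 builder.py \
  -m HuggingFaceTB/SmolLM3-3B \
  -o ./smollm3-onnx \
  -p int4 \
  -e cpu \
  -c ./hf_cache
```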

src/python/py/models/builder.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -3640,6 +3640,35 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.rotemb_attrs["rescale_factors"] = 1.0 / config.compression_ratio
 
 
+class SmolLM3Model(LlamaModel):
+    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
+        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
+        self.layer_types = config.layer_types
+        self.no_rope_layers = config.no_rope_layers
+
+    def make_attention(self, layer_id, attention, root_input, **kwargs):
+        # SmolLM3 uses per-layer conditional RoPE and Sliding Window Attention.
+        # So, we temporarily modify the model's attributes before calling the
+        # base `make_attention` method, then restore them immediately after.
+        original_use_rope = self.attention_attrs["use_rope_in_attn"]
+        original_window_size = self.window_size
+
+        # Enable/disable RoPE for the current layer.
+        self.attention_attrs["use_rope_in_attn"] = bool(self.no_rope_layers[layer_id])
+
+        # Set the sliding window size for the current layer.
+        assert self.layer_types[layer_id] in {"sliding_attention", "full_attention"}
+        if self.layer_types[layer_id] == "full_attention":
+            self.window_size = -1
+
+        # Call the original `make_attention` with the temporarily modified settings.
+        super().make_attention(layer_id, attention, root_input, **kwargs)
+
+        # Restore original values.
+        self.attention_attrs["use_rope_in_attn"] = original_use_rope
+        self.window_size = original_window_size
+
+
 def check_extra_options(kv_pairs):
     """
     Check key-value pairs and set values correctly
@@ -3828,6 +3857,8 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
         onnx_model = QwenModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
     elif config.architectures[0] == "Qwen3ForCausalLM":
         onnx_model = Qwen3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
+    elif config.architectures[0] == "SmolLM3ForCausalLM":
+        onnx_model = SmolLM3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
     else:
         raise NotImplementedError(f"The {hf_name} model is not currently supported.")
 
```
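The `make_attention` override above is driven by two per-layer lists read from the Hugging Face config: `layer_types` chooses between full and sliding-window attention, and `no_rope_layers` stores 1 for layers that apply RoPE and 0 for NoPE layers. A standalone sketch of that mapping, with made-up values (the real lists come from the checkpoint's `config.json`):

```python
# Illustrative per-layer settings in the style of a SmolLM3 config.json;
# the lists and the window size below are made up for demonstration.
layer_types = ["full_attention", "full_attention", "full_attention", "sliding_attention"]
no_rope_layers = [1, 1, 1, 0]  # 1 = apply RoPE in this layer, 0 = NoPE layer

for layer_id, layer_type in enumerate(layer_types):
    use_rope = bool(no_rope_layers[layer_id])
    # Full-attention layers disable the sliding window (-1 means unlimited).
    window_size = -1 if layer_type == "full_attention" else 4096
    print(f"layer {layer_id}: use_rope={use_rope}, window_size={window_size}")
```

Saving and restoring the original attributes matters because the base class reads `attention_attrs` and `window_size` as shared state; without the restore, one layer's temporary settings would leak into every later layer.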
