Commit 18fb095
[Model builder] Add support for Ernie 4.5 models (#1608)
Enables exporting the new Ernie 4.5 models via onnxruntime-genai: https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT

I've uploaded the converted model to https://huggingface.co/onnx-community/ERNIE-4.5-0.3B-ONNX. Currently only supports the non-MoE version... but maybe someone can help with the MoE version: https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT

---

Models tested and validated with python ort & transformers.js (huggingface/transformers.js#1354):

```py
from transformers import AutoConfig, AutoTokenizer
import onnxruntime
import numpy as np

# 1. Load config, processor, and model
path_to_model = "./path/to/model"
config = AutoConfig.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", trust_remote_code=True)
decoder_session = onnxruntime.InferenceSession(f"{path_to_model}/model.onnx")

## Set config values
num_key_value_heads = config.num_key_value_heads
head_dim = config.head_dim
num_hidden_layers = config.num_hidden_layers
eos_token_id = config.eos_token_id

# 2. Prepare inputs
## Create input messages
messages = [
  { "role": "system", "content": "You are a helpful assistant." },
  { "role": "user", "content": "Write me a poem about Machine Learning." },
]

## Apply tokenizer
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np")

## Prepare decoder inputs
batch_size = inputs['input_ids'].shape[0]
past_key_values = {
    f'past_key_values.{layer}.{kv}': np.zeros([batch_size, num_key_value_heads, 0, head_dim], dtype=np.float32)
    for layer in range(num_hidden_layers)
    for kv in ('key', 'value')
}
input_ids = inputs['input_ids']
position_ids = np.tile(np.arange(1, input_ids.shape[-1] + 1), (batch_size, 1))
attention_mask = np.ones_like(input_ids, dtype=np.int64)

# 3. Generation loop
max_new_tokens = 1024
generated_tokens = np.array([[]], dtype=np.int64)
for i in range(max_new_tokens):
    logits, *present_key_values = decoder_session.run(None, dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        **past_key_values,
    ))

    ## Update values for next generation loop
    input_ids = logits[:, -1].argmax(-1, keepdims=True)
    attention_mask = np.concatenate([attention_mask, np.ones_like(input_ids, dtype=np.int64)], axis=-1)
    position_ids = position_ids[:, -1:] + 1
    for j, key in enumerate(past_key_values):
        past_key_values[key] = present_key_values[j]
    generated_tokens = np.concatenate([generated_tokens, input_ids], axis=-1)
    if (input_ids == eos_token_id).all():
        break

    ## (Optional) Streaming
    print(tokenizer.decode(input_ids[0]), end='', flush=True)
print()

# 4. Output result
print(tokenizer.batch_decode(generated_tokens))
```

---------

Co-authored-by: kunal-vaishnavi <[email protected]>
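For reference, a conversion along these lines can be produced with the model builder; the snippet below is an illustrative sketch rather than part of the commit — the output and cache paths are placeholders, and the exact flags may differ across onnxruntime-genai versions:

```sh
# Export an FP32 CPU ONNX model for ERNIE-4.5-0.3B-PT (paths are placeholders)
python3 -m onnxruntime_genai.models.builder \
  -m baidu/ERNIE-4.5-0.3B-PT \
  -o ./ERNIE-4.5-0.3B-ONNX \
  -p fp32 \
  -e cpu \
  -c ./hf_cache
```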
1 parent 2f2ad90 · commit 18fb095

4 files changed · +22 −4 lines changed

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -20,7 +20,7 @@ See documentation at https://onnxruntime.ai/docs/genai.
 
 |Support matrix|Supported now|Under development|On the roadmap|
 | -------------- | ------------- | ----------------- | -------------- |
-| Model architectures | DeepSeek <br/> Gemma <br/> Llama * <br/> Mistral + <br/> Phi (language + vision) <br/> Qwen <br/> Nemotron <br/> Granite <br/> AMD OLMo | Whisper | Stable diffusion |
+| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Gemma <br/> Granite <br/> Llama * <br/> Mistral + <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen | Whisper | Stable diffusion |
 |API| Python <br/>C# <br/>C/C++ <br/> Java ^ |Objective-C||
 |Platform| Linux <br/> Windows <br/>Mac ^ <br/>Android ^ ||iOS |||
 |Architecture|x86 <br/> x64 <br/> Arm64 ~ ||||
```

src/models/model_type.h

Lines changed: 1 addition & 1 deletion

```diff
@@ -12,7 +12,7 @@ namespace Generators {
 struct ModelType {
   inline static bool IsLLM(const std::string& model_type) {
     // Large-language model (LLM)
-    static constexpr std::array<std::string_view, 17> LLM = {"chatglm", "decoder", "gemma", "gemma2", "gemma3_text", "gpt2", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3"};
+    static constexpr std::array<std::string_view, 18> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3"};
     return std::find(LLM.begin(), LLM.end(), model_type) != LLM.end();
   }
 
```
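For context, the `"ernie4_5"` string added here is expected to match the `model_type` declared by the Hugging Face checkpoint's config, which is what the builder and runtime key off of. A quick sanity check (illustrative snippet, not part of the commit):

```py
from transformers import AutoConfig

# Inspect the identifiers used to recognize ERNIE 4.5 checkpoints.
config = AutoConfig.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", trust_remote_code=True)
print(config.model_type)        # expected: "ernie4_5"
print(config.architectures[0])  # expected: "Ernie4_5_ForCausalLM"
```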

src/python/py/models/README.md

Lines changed: 4 additions & 2 deletions

```diff
@@ -31,15 +31,17 @@ This folder contains the model builder for quickly creating optimized and quanti
 
 The tool currently supports the following model architectures.
 
+- AMD OLMo
 - ChatGLM
+- DeepSeek
+- ERNIE 4.5
 - Gemma
 - Granite
-- LLaMA
+- Llama
 - Mistral
 - Nemotron
 - Phi
 - Qwen
-- AMD OLMo
 
 It is intended for supporting the latest, popular state-of-the-art models.
 
```
src/python/py/models/builder.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -3619,6 +3619,20 @@ def make_rotary_embedding_caches(self, **kwargs):
         return super().make_rotary_embedding_caches(cos_cache_name=cos_cache_name, sin_cache_name=sin_cache_name)
 
 
+class ErnieModel(MistralModel):
+    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
+        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
+
+        # Ernie uses interleaved rotary position embeddings.
+        self.rotemb_attrs["interleaved"] = 1
+
+        # Ernie uses a `compression_ratio` for its RoPE scaling.
+        # The original RoPE logic in ernie is: position_ids / compression_ratio,
+        # which is equivalent to scaling the frequencies (inv_freq) by 1 / compression_ratio.
+        if hasattr(config, "compression_ratio") and config.compression_ratio != 1.0:
+            self.rotemb_attrs["rescale_factors"] = 1.0 / config.compression_ratio
+
+
 def check_extra_options(kv_pairs):
     """
     Check key-value pairs and set values correctly
@@ -3739,6 +3753,8 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
         # Quantized ChatGLM model has ChatGLMForConditionalGeneration as architecture whereas HF model as the latter
         config.hidden_act = "swiglu"
         onnx_model = ChatGLMModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
+    elif config.architectures[0] == "Ernie4_5_ForCausalLM":
+        onnx_model = ErnieModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
     elif config.architectures[0] == "GemmaForCausalLM":
         onnx_model = GemmaModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
     elif config.architectures[0] == "Gemma2ForCausalLM":
```
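To illustrate the `compression_ratio` handling introduced above, here is a small numerical sketch (not part of the commit; the values are arbitrary and the angle computation follows the standard RoPE formulation) showing that dividing position ids by the ratio yields the same rotation angles as rescaling the inverse frequencies:

```py
import numpy as np

# Standard RoPE: angle[pos, i] = position_ids[pos] * inv_freq[i]
head_dim = 64
compression_ratio = 4.0  # arbitrary value for illustration
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2) / head_dim))
position_ids = np.arange(16, dtype=np.float64)

# Ernie-style: scale the positions before computing the angles.
angles_from_scaled_positions = np.outer(position_ids / compression_ratio, inv_freq)

# Builder-style (rescale_factors): keep positions as-is and rescale the frequencies.
angles_from_scaled_freqs = np.outer(position_ids, inv_freq / compression_ratio)

# Both formulations produce identical cos/sin caches.
assert np.allclose(angles_from_scaled_positions, angles_from_scaled_freqs)
```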
