Commit 18fb095
[Model builder] Add support for Ernie 4.5 models (#1608)
Enables exporting the new Ernie 4.5 models via onnxruntime-genai:
https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT
I've uploaded the converted model to
https://huggingface.co/onnx-community/ERNIE-4.5-0.3B-ONNX.
Currently, only the non-MoE version is supported; help adding support for
the MoE version would be welcome:
https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT
---
Models tested and validated with Python onnxruntime and
[transformers.js](huggingface/transformers.js#1354):
```py
from transformers import AutoConfig, AutoTokenizer
import onnxruntime
import numpy as np
# 1. Load config, tokenizer, and model
path_to_model = "./path/to/model"
config = AutoConfig.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", trust_remote_code=True)
decoder_session = onnxruntime.InferenceSession(f"{path_to_model}/model.onnx")
## Set config values
num_key_value_heads = config.num_key_value_heads
head_dim = config.head_dim
num_hidden_layers = config.num_hidden_layers
eos_token_id = config.eos_token_id
# 2. Prepare inputs
## Create input messages
messages = [
    { "role": "system", "content": "You are a helpful assistant." },
    { "role": "user", "content": "Write me a poem about Machine Learning." },
]
## Apply tokenizer
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np")
## Prepare decoder inputs
batch_size = inputs['input_ids'].shape[0]
past_key_values = {
    f'past_key_values.{layer}.{kv}': np.zeros([batch_size, num_key_value_heads, 0, head_dim], dtype=np.float32)
    for layer in range(num_hidden_layers)
    for kv in ('key', 'value')
}
input_ids = inputs['input_ids']
position_ids = np.tile(np.arange(1, input_ids.shape[-1] + 1), (batch_size, 1))
attention_mask = np.ones_like(input_ids, dtype=np.int64)
# 3. Generation loop
max_new_tokens = 1024
generated_tokens = np.empty((batch_size, 0), dtype=np.int64)  # one row per batch item
for i in range(max_new_tokens):
    logits, *present_key_values = decoder_session.run(None, dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        **past_key_values,
    ))
    ## Update values for next generation loop
    input_ids = logits[:, -1].argmax(-1, keepdims=True)
    attention_mask = np.concatenate([attention_mask, np.ones_like(input_ids, dtype=np.int64)], axis=-1)
    position_ids = position_ids[:, -1:] + 1
    for j, key in enumerate(past_key_values):
        past_key_values[key] = present_key_values[j]
    generated_tokens = np.concatenate([generated_tokens, input_ids], axis=-1)
    if (input_ids == eos_token_id).all():
        break
    ## (Optional) Streaming
    print(tokenizer.decode(input_ids[0]), end='', flush=True)
print()
# 4. Output result
print(tokenizer.batch_decode(generated_tokens))
```
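The bookkeeping in the generation loop above can be checked without downloading the model. The following is a minimal sketch, not part of this commit: `fake_decoder` is a hypothetical stand-in for `decoder_session.run` that always favours the token `(last_id + 1) % vocab_size`, and the KV cache is left out for brevity. It mirrors how the example updates `input_ids`, `attention_mask`, and `position_ids` (which start at 1) on each step.

```python
import numpy as np


def fake_decoder(input_ids, vocab_size=8):
    """Hypothetical stub for the ONNX decoder (an assumption, not a real API).

    Returns logits of shape (batch, seq, vocab) whose argmax at every
    position is (token_id + 1) % vocab_size.
    """
    batch, seq = input_ids.shape
    logits = np.zeros((batch, seq, vocab_size), dtype=np.float32)
    next_ids = (input_ids + 1) % vocab_size
    np.put_along_axis(logits, next_ids[..., None], 1.0, axis=-1)
    return logits


def greedy_generate(prompt_ids, eos_token_id, max_new_tokens=16):
    """Greedy decoding with the same per-step updates as the example above."""
    batch_size, prompt_len = prompt_ids.shape
    input_ids = prompt_ids
    attention_mask = np.ones_like(input_ids, dtype=np.int64)
    # Position ids start at 1, as in the example above.
    position_ids = np.tile(np.arange(1, prompt_len + 1), (batch_size, 1))
    generated = np.empty((batch_size, 0), dtype=np.int64)

    for _ in range(max_new_tokens):
        logits = fake_decoder(input_ids)
        # Pick the most likely token at the last position only.
        input_ids = logits[:, -1].argmax(-1, keepdims=True)
        # The mask grows by one column per generated token.
        attention_mask = np.concatenate(
            [attention_mask, np.ones_like(input_ids, dtype=np.int64)], axis=-1
        )
        # After the prefill step, only the newest position is fed back.
        position_ids = position_ids[:, -1:] + 1
        generated = np.concatenate([generated, input_ids], axis=-1)
        if (input_ids == eos_token_id).all():
            break
    return generated, attention_mask, position_ids


gen, mask, pos = greedy_generate(np.array([[1, 2, 3]], dtype=np.int64), eos_token_id=6)
print(gen.tolist(), mask.shape, pos.tolist())
```

With this stub, the prompt `[1, 2, 3]` yields the tokens `[4, 5, 6]` and stops at the assumed EOS id 6; the attention mask ends with one column per prompt and generated token, while `position_ids` holds only the latest position.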
---------
Co-authored-by: kunal-vaishnavi <[email protected]>
1 parent 2f2ad90 commit 18fb095
File tree
4 files changed, +22 -4 lines changed
- src
  - models
  - python/py/models