Skip to content

Commit 76c56b0

Browse files
authored
support ERNIE (#4757)
1 parent d066874 commit 76c56b0

File tree

15 files changed

+98
-17
lines changed

15 files changed

+98
-17
lines changed

docs/source/Instruction/支持的模型和数据集.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,12 @@
560560
|[rednote-hilab/dots.llm1.base](https://modelscope.cn/models/rednote-hilab/dots.llm1.base)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.base](https://huggingface.co/rednote-hilab/dots.llm1.base)|
561561
|[rednote-hilab/dots.llm1.inst](https://modelscope.cn/models/rednote-hilab/dots.llm1.inst)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.inst](https://huggingface.co/rednote-hilab/dots.llm1.inst)|
562562
|[Tencent-Hunyuan/Hunyuan-A13B-Instruct](https://modelscope.cn/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct)|hunyuan|hunyuan|-|✘|-|[tencent/Hunyuan-A13B-Instruct](https://huggingface.co/tencent/Hunyuan-A13B-Instruct)|
563+
|[PaddlePaddle/ERNIE-4.5-0.3B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-0.3B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-0.3B-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT)|
564+
|[PaddlePaddle/ERNIE-4.5-0.3B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-0.3B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-0.3B-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT)|
565+
|[PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-21B-A3B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Base-PT)|
566+
|[PaddlePaddle/ERNIE-4.5-21B-A3B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-21B-A3B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT)|
567+
|[PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-300B-A47B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-300B-A47B-Base-PT)|
568+
|[PaddlePaddle/ERNIE-4.5-300B-A47B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-300B-A47B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-300B-A47B-PT](https://huggingface.co/baidu/ERNIE-4.5-300B-A47B-PT)|
563569
|[answerdotai/ModernBERT-base](https://modelscope.cn/models/answerdotai/ModernBERT-base)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)|
564570
|[answerdotai/ModernBERT-large](https://modelscope.cn/models/answerdotai/ModernBERT-large)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)|
565571
|[iic/gte-modernbert-base](https://modelscope.cn/models/iic/gte-modernbert-base)|modern_bert_gte|dummy|transformers>=4.48|✘|bert, embedding|[Alibaba-NLP/gte-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-modernbert-base)|

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,12 @@ The table below introduces the models integrated with ms-swift:
560560
|[rednote-hilab/dots.llm1.base](https://modelscope.cn/models/rednote-hilab/dots.llm1.base)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.base](https://huggingface.co/rednote-hilab/dots.llm1.base)|
561561
|[rednote-hilab/dots.llm1.inst](https://modelscope.cn/models/rednote-hilab/dots.llm1.inst)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.inst](https://huggingface.co/rednote-hilab/dots.llm1.inst)|
562562
|[Tencent-Hunyuan/Hunyuan-A13B-Instruct](https://modelscope.cn/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct)|hunyuan|hunyuan|-|✘|-|[tencent/Hunyuan-A13B-Instruct](https://huggingface.co/tencent/Hunyuan-A13B-Instruct)|
563+
|[PaddlePaddle/ERNIE-4.5-0.3B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-0.3B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-0.3B-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT)|
564+
|[PaddlePaddle/ERNIE-4.5-0.3B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-0.3B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-0.3B-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT)|
565+
|[PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-21B-A3B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Base-PT)|
566+
|[PaddlePaddle/ERNIE-4.5-21B-A3B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-21B-A3B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT)|
567+
|[PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-300B-A47B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-300B-A47B-Base-PT)|
568+
|[PaddlePaddle/ERNIE-4.5-300B-A47B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-300B-A47B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-300B-A47B-PT](https://huggingface.co/baidu/ERNIE-4.5-300B-A47B-PT)|
563569
|[answerdotai/ModernBERT-base](https://modelscope.cn/models/answerdotai/ModernBERT-base)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)|
564570
|[answerdotai/ModernBERT-large](https://modelscope.cn/models/answerdotai/ModernBERT-large)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)|
565571
|[iic/gte-modernbert-base](https://modelscope.cn/models/iic/gte-modernbert-base)|modern_bert_gte|dummy|transformers>=4.48|✘|bert, embedding|[Alibaba-NLP/gte-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-modernbert-base)|

swift/llm/model/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class LLMModelType:
120120
mimo_rl = 'mimo_rl'
121121
dots1 = 'dots1'
122122
hunyuan = 'hunyuan'
123+
ernie = 'ernie'
123124

124125

125126
class BertModelType:

swift/llm/model/model/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from . import (baai, baichuan, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, microsoft,
2-
minicpm, minimax, mistral, mllm, moonshot, mplug, openbuddy, qwen, skywork, stepfun, telechat, valley,
3-
yi)
1+
from . import (baai, baichuan, baidu, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba,
2+
microsoft, minicpm, minimax, mistral, mllm, moonshot, mplug, openbuddy, qwen, skywork, stepfun, telechat,
3+
valley, yi)

swift/llm/model/model/baidu.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
from swift.llm import TemplateType
3+
from swift.utils import get_logger
4+
from ..constant import LLMModelType
5+
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
6+
7+
logger = get_logger()
8+
9+
register_model(
10+
ModelMeta(
11+
LLMModelType.ernie,
12+
[
13+
ModelGroup([
14+
Model('PaddlePaddle/ERNIE-4.5-0.3B-Base-PT', 'baidu/ERNIE-4.5-0.3B-PT'),
15+
Model('PaddlePaddle/ERNIE-4.5-0.3B-PT', 'baidu/ERNIE-4.5-0.3B-PT'),
16+
]),
17+
ModelGroup([
18+
Model('PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT', 'baidu/ERNIE-4.5-21B-A3B-Base-PT'),
19+
Model('PaddlePaddle/ERNIE-4.5-21B-A3B-PT', 'baidu/ERNIE-4.5-21B-A3B-PT'),
20+
Model('PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT', 'baidu/ERNIE-4.5-300B-A47B-Base-PT'),
21+
Model('PaddlePaddle/ERNIE-4.5-300B-A47B-PT', 'baidu/ERNIE-4.5-300B-A47B-PT'),
22+
]),
23+
],
24+
TemplateType.ernie,
25+
get_model_tokenizer_with_flash_attn,
26+
architectures=['Ernie4_5_ForCausalLM', 'Ernie4_5_MoeForCausalLM'],
27+
))

swift/llm/template/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class LLMTemplateType:
8686
mimo_rl = 'mimo_rl'
8787
dots1 = 'dots1'
8888
hunyuan = 'hunyuan'
89+
ernie = 'ernie'
8990

9091
aya = 'aya'
9192
c4ai = 'c4ai'
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from . import (bert, deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft,
2-
minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley, yi)
1+
from . import (baidu, bert, deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez,
2+
microsoft, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley,
3+
yi)

swift/llm/template/template/baidu.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
from dataclasses import dataclass, field
3+
from typing import Optional
4+
5+
from ..constant import LLMTemplateType
6+
from ..register import TemplateMeta, register_template
7+
from ..utils import Prompt
8+
9+
10+
@dataclass
11+
class ERNIETemplateMeta(TemplateMeta):
12+
prefix: Prompt = field(default_factory=lambda: ['<|begin_of_sentence|>'])
13+
prompt: Prompt = field(default_factory=lambda: ['User: {{QUERY}}\nAssistant: '])
14+
chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end_of_sentence|>'])
15+
suffix: Prompt = field(default_factory=lambda: ['</s>'])
16+
system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|begin_of_sentence|>{{SYSTEM}}\n'])
17+
18+
19+
register_template(ERNIETemplateMeta(LLMTemplateType.ernie))

swift/llm/template/template/deepseek.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2-
import os
32
from dataclasses import dataclass, field
43
from typing import Any, Dict, List, Optional
54

swift/megatron/argument/megatron_args.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class MegatronArguments(ExtraMegatronArguments):
135135
position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'relative', 'none'] = 'rope'
136136
rotary_base: Optional[int] = None
137137
rotary_percent: float = 1.
138+
rotary_interleaved: Optional[bool] = None
138139
normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm'
139140
norm_epsilon: Optional[float] = None
140141
swiglu: Optional[bool] = None
@@ -228,6 +229,8 @@ def _set_default(self):
228229
self.norm_epsilon = 1e-5
229230
if self.rotary_base is None:
230231
self.rotary_base = 10000
232+
if self.rotary_interleaved is None:
233+
self.rotary_interleaved = False
231234
if self.attention_dropout is None:
232235
self.attention_dropout = 0.
233236
if self.untie_embeddings_and_output_weights is None:

0 commit comments

Comments
 (0)