Commit fcc1c2d

[model] support Kimi-K2 template (#4925)
Parent: 3f09697

4 files changed (+32, -6 lines)

docs/source/Instruction/支持的模型和数据集.md (2 additions, 0 deletions)

@@ -552,6 +552,8 @@
 |[AI-ModelScope/aya-expanse-32b](https://modelscope.cn/models/AI-ModelScope/aya-expanse-32b)|aya|aya|transformers>=4.44.0|✘|-|[CohereForAI/aya-expanse-32b](https://huggingface.co/CohereForAI/aya-expanse-32b)|
 |[moonshotai/Moonlight-16B-A3B](https://modelscope.cn/models/moonshotai/Moonlight-16B-A3B)|moonlight|moonlight|transformers<4.49|&#x2714;|-|[moonshotai/Moonlight-16B-A3B](https://huggingface.co/moonshotai/Moonlight-16B-A3B)|
 |[moonshotai/Moonlight-16B-A3B-Instruct](https://modelscope.cn/models/moonshotai/Moonlight-16B-A3B-Instruct)|moonlight|moonlight|transformers<4.49|&#x2714;|-|[moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct)|
+|[moonshotai/Kimi-K2-Base](https://modelscope.cn/models/moonshotai/Kimi-K2-Base)|moonlight|moonlight|transformers<4.49|&#x2714;|-|[moonshotai/Kimi-K2-Base](https://huggingface.co/moonshotai/Kimi-K2-Base)|
+|[moonshotai/Kimi-K2-Instruct](https://modelscope.cn/models/moonshotai/Kimi-K2-Instruct)|moonlight|moonlight|transformers<4.49|&#x2714;|-|[moonshotai/Kimi-K2-Instruct](https://huggingface.co/moonshotai/Kimi-K2-Instruct)|
 |[XiaomiMiMo/MiMo-7B-Base](https://modelscope.cn/models/XiaomiMiMo/MiMo-7B-Base)|mimo|qwen|transformers>=4.37|&#x2714;|-|[XiaomiMiMo/MiMo-7B-Base](https://huggingface.co/XiaomiMiMo/MiMo-7B-Base)|
 |[XiaomiMiMo/MiMo-7B-SFT](https://modelscope.cn/models/XiaomiMiMo/MiMo-7B-SFT)|mimo|qwen|transformers>=4.37|&#x2714;|-|[XiaomiMiMo/MiMo-7B-SFT](https://huggingface.co/XiaomiMiMo/MiMo-7B-SFT)|
 |[XiaomiMiMo/MiMo-7B-RL-Zero](https://modelscope.cn/models/XiaomiMiMo/MiMo-7B-RL-Zero)|mimo|qwen|transformers>=4.37|&#x2714;|-|[XiaomiMiMo/MiMo-7B-RL-Zero](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL-Zero)|

docs/source_en/Instruction/Supported-models-and-datasets.md (2 additions, 0 deletions)

@@ -552,6 +552,8 @@ The table below introduces the models integrated with ms-swift:
 |[AI-ModelScope/aya-expanse-32b](https://modelscope.cn/models/AI-ModelScope/aya-expanse-32b)|aya|aya|transformers>=4.44.0|&#x2718;|-|[CohereForAI/aya-expanse-32b](https://huggingface.co/CohereForAI/aya-expanse-32b)|
 |[moonshotai/Moonlight-16B-A3B](https://modelscope.cn/models/moonshotai/Moonlight-16B-A3B)|moonlight|moonlight|transformers<4.49|&#x2714;|-|[moonshotai/Moonlight-16B-A3B](https://huggingface.co/moonshotai/Moonlight-16B-A3B)|
 |[moonshotai/Moonlight-16B-A3B-Instruct](https://modelscope.cn/models/moonshotai/Moonlight-16B-A3B-Instruct)|moonlight|moonlight|transformers<4.49|&#x2714;|-|[moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct)|
+|[moonshotai/Kimi-K2-Base](https://modelscope.cn/models/moonshotai/Kimi-K2-Base)|moonlight|moonlight|transformers<4.49|&#x2714;|-|[moonshotai/Kimi-K2-Base](https://huggingface.co/moonshotai/Kimi-K2-Base)|
+|[moonshotai/Kimi-K2-Instruct](https://modelscope.cn/models/moonshotai/Kimi-K2-Instruct)|moonlight|moonlight|transformers<4.49|&#x2714;|-|[moonshotai/Kimi-K2-Instruct](https://huggingface.co/moonshotai/Kimi-K2-Instruct)|
 |[XiaomiMiMo/MiMo-7B-Base](https://modelscope.cn/models/XiaomiMiMo/MiMo-7B-Base)|mimo|qwen|transformers>=4.37|&#x2714;|-|[XiaomiMiMo/MiMo-7B-Base](https://huggingface.co/XiaomiMiMo/MiMo-7B-Base)|
 |[XiaomiMiMo/MiMo-7B-SFT](https://modelscope.cn/models/XiaomiMiMo/MiMo-7B-SFT)|mimo|qwen|transformers>=4.37|&#x2714;|-|[XiaomiMiMo/MiMo-7B-SFT](https://huggingface.co/XiaomiMiMo/MiMo-7B-SFT)|
 |[XiaomiMiMo/MiMo-7B-RL-Zero](https://modelscope.cn/models/XiaomiMiMo/MiMo-7B-RL-Zero)|mimo|qwen|transformers>=4.37|&#x2714;|-|[XiaomiMiMo/MiMo-7B-RL-Zero](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL-Zero)|

swift/llm/model/model/moonshot.py (4 additions, 0 deletions)

@@ -14,6 +14,10 @@
                 Model('moonshotai/Moonlight-16B-A3B', 'moonshotai/Moonlight-16B-A3B'),
                 Model('moonshotai/Moonlight-16B-A3B-Instruct', 'moonshotai/Moonlight-16B-A3B-Instruct'),
             ]),
+            ModelGroup([
+                Model('moonshotai/Kimi-K2-Base', 'moonshotai/Kimi-K2-Base'),
+                Model('moonshotai/Kimi-K2-Instruct', 'moonshotai/Kimi-K2-Instruct'),
+            ]),
         ],
         TemplateType.moonlight,
         get_model_tokenizer_with_flash_attn,
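With this registration, both Kimi-K2 checkpoints resolve to the existing moonlight template. The sketch below (not part of the commit) shows how the new entries would be exercised through the same swift.llm helpers used in swift/megatron/utils/convert.py further down; loading the full Kimi-K2 checkpoint is resource-intensive, so treat it as illustrative, and note that the (model, processor) return shape of get_model_tokenizer is an assumption based on how these helpers are used elsewhere in the codebase.

```python
# Illustrative sketch only (assumption, not from this commit): exercising the new
# Kimi-K2 registration through the helpers also used in convert.py below.
from swift.llm import get_model_tokenizer, get_template

# Resolves to the moonlight model_type/template registered above.
hf_model, processor = get_model_tokenizer('moonshotai/Kimi-K2-Instruct')
template = get_template(hf_model.model_meta.template, processor)

# Encode a chat message with the Kimi-K2 (moonlight) chat template.
inputs = template.encode({'messages': [{'role': 'user', 'content': 'who are you?'}]})
print(len(inputs['input_ids']))
```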

swift/megatron/utils/convert.py (24 additions, 6 deletions)

@@ -11,7 +11,7 @@
 from megatron.training.initialize import initialize_megatron
 from megatron.training.utils import get_ltor_masks_and_position_ids

-from swift.llm import ExportArguments, HfConfigFactory, get_model_tokenizer, get_template, save_checkpoint
+from swift.llm import ExportArguments, HfConfigFactory, get_model_tokenizer, get_template, save_checkpoint, to_device
 from swift.utils import get_logger, get_n_params_grads
 from ..argument import MegatronArguments
 from ..model import get_megatron_model_meta

@@ -87,21 +87,37 @@ def test_convert_precision(hf_model, mg_model, processor, torch_dtype=torch.floa
     _test_params_sum(mg_model)

     template = get_template(hf_model.model_meta.template, processor)
-    input_ids = template.encode({'messages': [{'role': 'user', 'content': 'who are you?'}]})['input_ids']
-    input_ids = torch.tensor(input_ids)[None].to('cuda')
+    template.set_mode('train')
+    inputs = template.encode({
+        'messages': [
+            {
+                'role': 'user',
+                'content': 'Introduction to ms-swift.'
+            },
+            {
+                'role':
+                'assistant',
+                'content':
+                'ms-swift is an official framework provided by the ModelScope community for fine-tuning '
+                'and deploying large language models and multi-modal large models.'
+            },
+        ]
+    })
+    inputs = to_device(template.data_collator([inputs]), 'cuda')

     HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False)
     share_embedding = mg_model.share_embeddings_and_output_weights
     hf_modules = _find_modules(hf_model)
     with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype, share_embedding=share_embedding):
-        hf_logits = hf_model(input_ids).logits
+        hf_logits = hf_model(**inputs).logits
     hf_model = hf_model.to('cpu')

+    input_ids = inputs['input_ids']
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True)
     packed_seq_params = None
     mg_torch_dtype = torch_dtype
     # thd
-    # from ..train.utils import get_packed_seq_params
+    # from ..trainers.utils import get_packed_seq_params
     # mg_torch_dtype = None
     # packed_seq_params = get_packed_seq_params(position_ids)
     # attention_mask = None

@@ -115,8 +131,10 @@ def test_convert_precision(hf_model, mg_model, processor, torch_dtype=torch.floa
         position_ids=position_ids,
         packed_seq_params=packed_seq_params)

-    mean_diff = (mg_logits - hf_logits).abs().mean().item()
+    token_mean_diff = (mg_logits - hf_logits).abs().mean(dim=-1)
+    mean_diff = token_mean_diff.mean().item()
     max_diff = (mg_logits - hf_logits).abs().max().item()
+    print(f'token_mean_diff: {token_mean_diff}')
     print(f'mean_diff: {mean_diff}, max_diff: {max_diff} (Please check that mean_diff is less than 0.1).')
     hf_tokens = hf_logits.argmax(-1)
     mg_tokens = mg_logits.argmax(-1)
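The precision check now averages the absolute logit difference over the vocabulary axis first, giving one value per token position before collapsing to the global mean. Below is a self-contained sketch of that arithmetic with plain torch and random stand-in tensors; the [batch, seq_len, vocab] shape mirrors what test_convert_precision compares, everything else is illustrative.

```python
import torch

# Stand-ins for the Megatron and HuggingFace logits, shaped [batch, seq_len, vocab].
mg_logits = torch.randn(1, 8, 32)
hf_logits = mg_logits + 0.01 * torch.randn(1, 8, 32)

# Mean absolute difference per token position (averaged over the vocab axis),
# so positions where the two implementations diverge stand out individually.
token_mean_diff = (mg_logits - hf_logits).abs().mean(dim=-1)

# Global summaries, the same quantities test_convert_precision prints.
mean_diff = token_mean_diff.mean().item()
max_diff = (mg_logits - hf_logits).abs().max().item()
print(f'token_mean_diff: {token_mean_diff}')
print(f'mean_diff: {mean_diff}, max_diff: {max_diff}')
```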
