-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[model] support olmoe #7140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[model] support olmoe #7140
Changes from 6 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
29342d0
add new model: olmoe
qianhao0713 f812a8c
Merge branch 'modelscope:main' into main
qianhao0713 1621950
Merge branch 'main' into main
qianhao0713 64cdd03
fix _set_attn_state in OLMoEBridge
qianhao0713 1da05ca
Merge branch 'main' into main
qianhao0713 4543a73
correct template for olmoe version 0924
qianhao0713 c8febdf
remove unused varibales in OLMoEBridge._set_attn_state
qianhao0713 db97467
fix template for olmoe_0924 and support olmoe_0924 megatron training
qianhao0713 ad89de6
add test_cases for olmoe and fix bugs for mcore 0.15
qianhao0713 b228e40
resolve conflict
qianhao0713 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| class LLMMegatronModelType: | ||
| gpt = 'gpt' | ||
| qwen3_next = 'qwen3_next' | ||
| olmoe = 'olmoe' | ||
| glm4 = 'glm4' | ||
|
|
||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,242 @@ | ||
| from copy import deepcopy | ||
| from typing import Optional | ||
|
|
||
| import megatron.core | ||
| import torch | ||
| import torch.distributed as dist | ||
| from megatron.core.extensions.transformer_engine import SplitAlongDim, TENorm | ||
| from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec | ||
| from megatron.core.process_groups_config import ModelCommProcessGroups | ||
| from megatron.core.transformer.attention import SelfAttention as SelfAttentionBase | ||
| from megatron.core.transformer.attention import SelfAttentionSubmodules | ||
| from megatron.core.transformer.enums import AttnMaskType, LayerType | ||
| from megatron.core.transformer.spec_utils import build_module | ||
| from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build | ||
| from megatron.core.transformer.transformer_config import TransformerConfig | ||
| from megatron.core.transformer.transformer_layer import get_transformer_layer_offset | ||
| from packaging import version | ||
|
|
||
| from swift.llm import ModelType | ||
| from swift.megatron.tuners import LoraParallelLinear | ||
| from ..constant import MegatronModelType | ||
| from ..gpt_bridge import GPTBridge | ||
| from ..register import MegatronModelMeta, register_megatron_model | ||
|
|
||
| mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') | ||
|
|
||
|
|
||
| class OLMoESelfAttention(SelfAttentionBase): | ||
|
|
||
| def __init__( | ||
| self, | ||
| config: TransformerConfig, | ||
| submodules: SelfAttentionSubmodules, | ||
| layer_number: int, | ||
| attn_mask_type=AttnMaskType.padding, | ||
| cp_comm_type: str = None, | ||
| model_comm_pgs: ModelCommProcessGroups = None, | ||
| ): | ||
| super().__init__( | ||
| config=config, | ||
| submodules=submodules, | ||
| layer_number=layer_number, | ||
| attn_mask_type=attn_mask_type, | ||
| cp_comm_type=cp_comm_type, | ||
| model_comm_pgs=model_comm_pgs, | ||
| ) | ||
| self.q_layernorm = build_module( | ||
| submodules.q_layernorm, | ||
| hidden_size=self.hidden_size_per_attention_head * self.num_attention_heads_per_partition, | ||
| config=self.config, | ||
| eps=self.config.layernorm_epsilon, | ||
| ) | ||
| self.k_layernorm = build_module( | ||
| submodules.k_layernorm, | ||
| hidden_size=self.hidden_size_per_attention_head * self.num_query_groups_per_partition, | ||
| config=self.config, | ||
| eps=self.config.layernorm_epsilon, | ||
| ) | ||
|
|
||
| def get_query_key_value_tensors(self, hidden_states, key_value_states=None): | ||
| """ | ||
| Derives `query`, `key` and `value` tensors from `hidden_states`. | ||
| """ | ||
| # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] | ||
| mixed_qkv, _ = self.linear_qkv(hidden_states) | ||
|
|
||
| # [sq, b, ng * (np/ng + 2) * hn] -> [sq, b, np * hn], [sq, b, ng * hn], [sq, b, ng * hn] | ||
| split_arg_list = [ | ||
| self.hidden_size_per_attention_head * self.num_attention_heads_per_partition, | ||
| self.hidden_size_per_attention_head * self.num_query_groups_per_partition, | ||
| self.hidden_size_per_attention_head * self.num_query_groups_per_partition | ||
| ] | ||
|
|
||
| if SplitAlongDim is not None: | ||
| (query, key, value) = SplitAlongDim(mixed_qkv, 2, split_arg_list) | ||
| else: | ||
| (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=2) | ||
|
|
||
| if self.q_layernorm is not None: | ||
| query = self.q_layernorm(query) | ||
|
|
||
| if self.k_layernorm is not None: | ||
| key = self.k_layernorm(key) | ||
|
|
||
| query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) | ||
| key = key.reshape(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) | ||
| value = value.reshape(value.size(0), value.size(1), -1, self.hidden_size_per_attention_head) | ||
|
|
||
| if self.config.test_mode: | ||
| self.run_realtime_tests() | ||
|
|
||
| return query, key, value | ||
|
|
||
|
|
||
| def get_olmoe_decoder_block_spec( | ||
| config: TransformerConfig, | ||
| vp_stage: Optional[int] = None, | ||
| ) -> TransformerBlockSubmodules: | ||
| """GPT block spec.""" | ||
| layer_norm_impl = TENorm | ||
| kwargs = {'use_kitchen': config.use_kitchen} if mcore_013 else {} | ||
| moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( | ||
| num_experts=config.num_moe_experts, | ||
| moe_grouped_gemm=config.moe_grouped_gemm, | ||
| qk_layernorm=True, | ||
| multi_latent_attention=False, | ||
| moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, | ||
| **kwargs, | ||
| ) | ||
| layer_specs = [] | ||
| for _ in range(config.num_layers): | ||
| layer_spec = deepcopy(moe_layer_spec) | ||
| layer_spec.submodules.self_attention.module = OLMoESelfAttention | ||
| layer_specs.append(layer_spec) | ||
|
|
||
| num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage) | ||
|
|
||
| if config.pipeline_model_parallel_layout is not None: | ||
| local_layer_specs = [ | ||
| layer_specs[layer_id] for layer_id in config.pipeline_model_parallel_layout.get_layer_id_list( | ||
| layer_type=LayerType.decoder, vp_stage=vp_stage) | ||
| ] | ||
| else: | ||
| offset = get_transformer_layer_offset(config, vp_stage=vp_stage) | ||
| local_layer_specs = layer_specs[offset:offset + num_layers_to_build] | ||
|
|
||
| # Block spec. | ||
| block_spec = TransformerBlockSubmodules(layer_specs=local_layer_specs, layer_norm=layer_norm_impl) | ||
|
|
||
| return block_spec | ||
|
|
||
|
|
||
| class OLMoEBridge(GPTBridge): | ||
|
|
||
| def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): | ||
| if to_mcore: | ||
| hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) | ||
| else: | ||
| hf_state_dict = {} | ||
| hf_attn = self.hf_layers[layer_idx].self_attn | ||
| args = self.args | ||
| num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) | ||
|
||
| hidden_size_block = args.hidden_size // self.fp8_block_size | ||
| if to_mcore: | ||
| if isinstance(mg_attn.linear_qkv, LoraParallelLinear): | ||
| lora_A = hf_state_dict['q_proj.lora_A.weight'].load() | ||
| assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and ( | ||
| lora_A == hf_state_dict['v_proj.lora_A.weight'].load() | ||
| ).all(), 'Need to ensure QKV\'s lora_A are consistent' | ||
| lora_B = torch.cat([ | ||
| hf_state_dict['q_proj.lora_B.weight'].load(), | ||
| hf_state_dict['k_proj.lora_B.weight'].load(), | ||
| hf_state_dict['v_proj.lora_B.weight'].load(), | ||
| ], dim=0) | ||
| self._set_weight(mg_attn.linear_qkv.lora_A[self._adapter_name].weight, lora_A, | ||
| 'linear_qkv.lora_A.weight') | ||
| self._set_weight(mg_attn.linear_qkv.lora_B[self._adapter_name].weight, lora_B, | ||
| 'linear_qkv.lora_B.weight') | ||
| else: | ||
| linear_qkv_weight = torch.cat([ | ||
| hf_state_dict['q_proj.weight'].load(), | ||
| hf_state_dict['k_proj.weight'].load(), | ||
| hf_state_dict['v_proj.weight'].load(), | ||
| ], dim=0) | ||
| qkv_scale_inv = None | ||
| if 'q_proj.weight_scale_inv' in hf_state_dict: | ||
| qkv_scale_inv = torch.cat([ | ||
| hf_state_dict['q_proj.weight_scale_inv'].load(), | ||
| hf_state_dict['k_proj.weight_scale_inv'].load(), | ||
| hf_state_dict['v_proj.weight_scale_inv'].load(), | ||
| ], dim=0) | ||
| self._set_weight( | ||
| mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight', hf_scale_inv=qkv_scale_inv) | ||
| else: | ||
| q_dim, kv_dim = hf_attn.q_proj.weight.shape[0], hf_attn.k_proj.weight.shape[0] | ||
| q_block = q_dim // self.fp8_block_size | ||
| kv_block = kv_dim // self.fp8_block_size | ||
| is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv, | ||
| LoraParallelLinear) and self._is_peft_format | ||
| is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda') | ||
| if self.pp_size > 1: | ||
| dist.all_reduce(is_lora, group=self.pp_group) | ||
| if is_lora: | ||
| lora_A, _ = self._get_weight( | ||
| None if mg_attn is None else mg_attn.linear_qkv.lora_A[self._adapter_name].weight.data, | ||
| f'linear_qkv.lora_A.{self._adapter_name}.weight') | ||
| lora_B, _ = self._get_weight( | ||
| None if mg_attn is None else mg_attn.linear_qkv.lora_B[self._adapter_name].weight.data, | ||
| f'linear_qkv.lora_B.{self._adapter_name}.weight') | ||
| if lora_A is not None: | ||
| self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) | ||
| for key in ['q_proj', 'k_proj', 'v_proj']: | ||
| hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone() | ||
| hf_state_dict['q_proj.lora_B.weight'] = lora_B[:q_dim, :].clone() | ||
| hf_state_dict['k_proj.lora_B.weight'] = lora_B[q_dim:-kv_dim, :].clone() | ||
| hf_state_dict['v_proj.lora_B.weight'] = lora_B[-kv_dim:, :].clone() | ||
| elif not self._is_peft_format: | ||
| mg_attn_weight, scale_inv = self._get_weight( | ||
| None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') | ||
| if mg_attn_weight is not None: | ||
| hf_state_dict['q_proj.weight'] = mg_attn_weight[:q_dim, :].clone() | ||
| hf_state_dict['k_proj.weight'] = mg_attn_weight[q_dim:-kv_dim, :].clone() | ||
| hf_state_dict['v_proj.weight'] = mg_attn_weight[-kv_dim:, :].clone() | ||
| if scale_inv is not None: | ||
| hf_state_dict['q_proj.weight_scale_inv'] = scale_inv[:q_block, :].clone() | ||
| hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[q_block:-kv_block, :].clone() | ||
| hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[-kv_block:, :].clone() | ||
| del mg_attn_weight | ||
| self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) | ||
| if args.add_qkv_bias and not self._is_peft_format: | ||
| if to_mcore: | ||
| linear_qkv_bias = torch.cat([ | ||
| hf_state_dict['q_proj.bias'].load(), | ||
| hf_state_dict['k_proj.bias'].load(), | ||
| hf_state_dict['v_proj.bias'].load(), | ||
| ], dim=0) | ||
| self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias') | ||
| else: | ||
| mg_attn_bias, _ = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.bias.data, | ||
| 'linear_qkv.bias') | ||
| if mg_attn_bias is not None: | ||
| hf_state_dict['q_proj.bias'] = mg_attn_bias[:q_dim].clone() | ||
| hf_state_dict['k_proj.bias'] = mg_attn_bias[q_dim:-kv_dim].clone() | ||
| hf_state_dict['v_proj.bias'] = mg_attn_bias[-kv_dim:].clone() | ||
| if args.qk_layernorm: | ||
| hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight' | ||
| hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight' | ||
| self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, hf_q_norm_key, to_mcore) | ||
| self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, hf_k_norm_key, to_mcore) | ||
| if to_mcore: | ||
| hf_state_dict = {} | ||
| else: | ||
| hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) | ||
| return hf_state_dict | ||
|
|
||
| register_megatron_model( | ||
| MegatronModelMeta( | ||
| MegatronModelType.olmoe, | ||
| [ModelType.olmoe], | ||
| get_transformer_layer_spec=get_olmoe_decoder_block_spec, | ||
| bridge_cls=OLMoEBridge, | ||
| )) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The templates for
olmoeandolmoe_0924are not correctly implemented according to their respective Hugging Face model cards andtokenizer_config.jsonfiles.olmoetemplate incorrectly uses|||IP_ADDRESS|||as a separator. This token is used for PII redaction in the training data, not as a chat separator.<|endoftext|>token after each message part (system, user).olmoetemplate should have a newline\nafter<|endoftext|>, whileolmoe_0924should not, based on their respectivechat_templatedefinitions.Here is a corrected implementation for both templates that aligns with the official chat templates: