diff --git "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" index f6e520c0da..bec829723e 100644 --- "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -65,7 +65,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 | torch | >=2.0 | 2.6.0/2.7.1 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | >=0.12 | 0.13 | | +| megatron_core | | 0.13 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.56.2 | | | modelscope | >=1.23 | | | diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 03b1558f3d..59ac888b65 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -65,7 +65,7 @@ Recommended Operating Environment: | torch | >=2.0 | 2.6.0/2.7.1 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | >=0.12 | 0.13 | | +| megatron_core | | 0.13 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.56.2 | | | modelscope | >=1.23 | | | diff --git a/swift/llm/model/patcher.py b/swift/llm/model/patcher.py index 1ab31aa8f7..58e2065f1e 100644 --- a/swift/llm/model/patcher.py +++ b/swift/llm/model/patcher.py @@ -467,7 +467,7 @@ def patch_tp_plan(load_model: bool): transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ: yield return - logger.info('Patch tp_plan.') + logger.info_once('Patch tp_plan.') WORLD_SIZE = os.environ.get('WORLD_SIZE') os.environ['_PATCH_WORLD_SIZE'] = WORLD_SIZE os.environ.pop('WORLD_SIZE') diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 0a5a41ebc1..a5a48f0897 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -13,7 +13,8 @@ if TYPE_CHECKING: from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main - from .utils import convert_hf2mcore, convert_mcore2hf, prepare_mcore_model, adapter_state_dict_context + from .convert import convert_hf2mcore, convert_mcore2hf + from .utils import prepare_mcore_model, adapter_state_dict_context from .argument import MegatronTrainArguments, MegatronRLHFArguments from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model from .trainers import MegatronTrainer, MegatronDPOTrainer @@ -21,7 +22,8 @@ else: _import_structure = { 'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'], - 'utils': ['convert_hf2mcore', 'convert_mcore2hf', 'prepare_mcore_model', 'adapter_state_dict_context'], + 'convert': ['convert_hf2mcore', 'convert_mcore2hf'], + 'utils': ['prepare_mcore_model', 'adapter_state_dict_context'], 'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments'], 'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'], 'trainers': ['MegatronTrainer', 'MegatronDPOTrainer'], diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 4cd2fc7a18..7b9a2efdd2 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -94,6 +94,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): torch_dtype: Optional[torch.dtype] = None padding_free: bool = True mlp_padding_free: bool = False + 
load_hf_checkpoint: bool = False + save_hf_checkpoint: bool = False # streaming dataloader dataloader_persistent_workers: bool = True dataloader_prefetch_factor: int = 10 diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 61535def15..cf4676bf97 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -9,6 +9,7 @@ from swift.llm.argument.base_args import to_abspath from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master from ..model import get_megatron_model_meta +from ..utils import convert_hf_config from .megatron_args import MegatronArguments logger = get_logger() @@ -23,7 +24,7 @@ def init_model_args(self, tokenizer, config): if self.task_type == 'seq_cls': self.problem_type = self.problem_type or getattr(config, 'problem_type', None) logger.info(f'args.problem_type: {self.problem_type}') - kwargs = self.megatron_model_meta.convert_hf_config(config) + kwargs = convert_hf_config(config) if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer): kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128 self.initialize_embedding = True @@ -75,8 +76,13 @@ def __post_init__(self): if self.num_workers > 1: self.num_workers = 1 logger.info('Using streaming dataset, setting args.num_workers to 1.') - if self.load is None and self.no_initialization: + if self.load is None and self.no_initialization and not self.load_hf_checkpoint: raise ValueError('You did not pass `--load`, so you need to set `--no_initialization false` ' 'to allow the model to initialize weights properly.') if self.cached_dataset and self.context_parallel_size > 1: raise ValueError('`cached_dataset` does not support context parallelism.') + + def get_model_kwargs(self): + res = super().get_model_kwargs() + res['download_model'] = self.load_hf_checkpoint + return res diff --git a/swift/megatron/utils/convert.py b/swift/megatron/convert.py similarity index 96% rename from swift/megatron/utils/convert.py rename to swift/megatron/convert.py index aa3202f580..7d94f96fc0 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/convert.py @@ -16,9 +16,9 @@ from swift.llm import (ExportArguments, HfConfigFactory, prepare_model_template, save_checkpoint, to_device, to_float_dtype) from swift.utils import get_logger, get_n_params_grads -from ..argument import MegatronArguments -from ..model import get_megatron_model_meta -from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard +from .argument import MegatronArguments +from .model import get_megatron_model_meta +from .utils import convert_hf_config, patch_torch_dist_shard logger = get_logger() @@ -238,7 +238,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' 
- kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + kwargs = convert_hf_config(processor.model_info.config) logger.info(f'megatron_config: {kwargs}') _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() @@ -246,7 +246,6 @@ def convert_hf2mcore(args: ExportArguments) -> None: current_convert_kwargs['moe_grouped_gemm'] = True megatron_args = MegatronArguments( **kwargs, **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype) - patch_megatron_tokenizer(processor) extra_args = megatron_args.parse_to_megatron() extra_args['model_info'] = args.model_info extra_args['model_meta'] = args.model_meta @@ -256,7 +255,8 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') - megatron_model_meta.convert_hf2mcore(hf_model, mg_model) + bridge = megatron_model_meta.bridge_cls() + bridge.load_state_dict(mg_model, bridge.convert_hf2mcore(hf_model.state_dict())) if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) del hf_model @@ -274,7 +274,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' - kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + kwargs = convert_hf_config(processor.model_info.config) logger.info(f'megatron_config: {kwargs}') _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() @@ -291,7 +291,6 @@ def convert_mcore2hf(args: ExportArguments) -> None: **current_convert_kwargs, save=args.output_dir if args.to_mcore else None, torch_dtype=args.torch_dtype) - patch_megatron_tokenizer(processor) extra_args = megatron_args.parse_to_megatron() extra_args['model_info'] = args.model_info extra_args['model_meta'] = args.model_meta @@ -312,7 +311,8 @@ def convert_mcore2hf(args: ExportArguments) -> None: logger.info('Megatron model created successfully.') if args.to_hf: hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) - megatron_model_meta.convert_mcore2hf(hf_model, mg_model) + bridge = megatron_model_meta.bridge_cls() + bridge.load_state_dict(hf_model, bridge.convert_mcore2hf(mg_model.state_dict())) if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) del mg_model diff --git a/swift/megatron/init.py b/swift/megatron/init.py index e5dfcd7725..d1892f1cba 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -379,6 +379,15 @@ def sharded_state_dict( TEGroupedLinear.sharded_state_dict = sharded_state_dict +def _patch_megatron_tokenizer(): + from megatron.training import global_vars + + def build_tokenizer(args): + return 'dummy_tokenizer' + + global_vars.build_tokenizer = build_tokenizer + + def _patch_peft_ModulesToSaveWrapper(): if version.parse(peft.__version__) >= version.parse('0.16'): from peft.utils import other as peft_module @@ -664,6 +673,7 @@ def _patch_megatron(): _patch_compile_helpers() _patch_build_train_valid_test_datasets() _patch_mrope() + _patch_megatron_tokenizer() logging.root.setLevel(logging_level) # revert logger level from swift.megatron import tuners # patch lora try: diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py index 3c882c9864..35b8777147 100644 --- a/swift/megatron/model/__init__.py +++ 
b/swift/megatron/model/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from . import gpt, mm_gpt +# from . import gpt, mm_gpt +from . import gpt from .constant import MegatronModelType from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py deleted file mode 100644 index 7a68537efc..0000000000 --- a/swift/megatron/model/config.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, Dict - -from swift.utils import get_logger - -logger = get_logger() -config_mapping = { - 'num_layers': ['num_hidden_layers'], - 'hidden_size': ['hidden_size'], - 'ffn_hidden_size': ['intermediate_size'], - 'num_attention_heads': ['num_attention_heads'], - 'num_query_groups': ['num_key_value_heads'], - 'max_position_embeddings': ['max_position_embeddings'], - 'norm_epsilon': ['rms_norm_eps'], - 'rotary_base': ['rope_theta'], - 'padded_vocab_size': ['vocab_size'], - 'attention_dropout': ['attention_dropout'], - 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], - 'swiglu': ['hidden_act'], - 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], - 'disable_bias_linear': ['mlp_bias'], - 'kv_channels': ['head_dim', 'v_head_dim'], - 'architectures': ['architectures'], - # moe - 'moe_ffn_hidden_size': ['moe_intermediate_size'], - 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], - 'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk', 'moe_k'], - 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts'], - 'moe_router_pre_softmax': ['norm_topk_prob'], - # deepseek - 'q_lora_rank': ['q_lora_rank'], - 'kv_lora_rank': ['kv_lora_rank'], - 'moe_router_score_function': ['scoring_func'], - 'qk_head_dim': ['qk_nope_head_dim'], - 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], - 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], - 'qk_layernorm': ['use_qk_norm'], - # qwen3_next - 'linear_num_value_heads': ['linear_num_value_heads'], - 'linear_num_key_heads': ['linear_num_key_heads'], - 'linear_key_head_dim': ['linear_key_head_dim'], - 'linear_value_head_dim': ['linear_value_head_dim'], - 'linear_conv_kernel_dim': ['linear_conv_kernel_dim'], - 'full_attention_interval': ['full_attention_interval'], - # other - 'original_max_position_embeddings': ['original_max_position_embeddings'], - 'partial_rotary_factor': ['partial_rotary_factor'], - 'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'], - 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'], -} - - -def convert_hf_config(config, _internal_call=False) -> Dict[str, Any]: - megatron_config = {} - for k, hf_keys in config_mapping.items(): - for hf_k in hf_keys: - if hasattr(config, hf_k): - hf_v = getattr(config, hf_k) - if hf_v is None: - continue - if k == 'rotary_base': - megatron_config[k] = int(hf_v) - elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: - megatron_config[k] = not hf_v - elif k == 'swiglu': - if hf_v == 'silu': - megatron_config[k] = True - else: - if k == 'kv_lora_rank': - megatron_config['multi_latent_attention'] = True - elif k == 'architectures': - if _internal_call: - k = 'llm_architectures' - megatron_config[k] = hf_v - break - for key in ['text_config', 'llm_config', 'thinker_config']: - if hasattr(config, key): - megatron_config.update(convert_hf_config(getattr(config, key), 
_internal_call=True)) - # compat llama3 - if getattr(config, 'rope_scaling', None) is not None: - if isinstance(config.rope_scaling, int): - megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'}, - elif isinstance(config.rope_scaling, dict): - megatron_config['rope_scaling'] = config.rope_scaling - return megatron_config diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index f3eb68e0cd..9dd370ccfd 100644 --- a/swift/megatron/model/gpt/__init__.py +++ b/swift/megatron/model/gpt/__init__.py @@ -4,9 +4,6 @@ from ..gpt_model import GPTModel from ..register import MegatronModelMeta, register_megatron_model from . import qwen3_next -from .config import convert_gpt_hf_config -from .hf2mcore import convert_hf2mcore -from .mcore2hf import convert_mcore2hf register_megatron_model( MegatronModelMeta( @@ -58,8 +55,4 @@ ModelType.deepseek_v3_1, ModelType.ernie_thinking, ], - model_cls=GPTModel, - convert_hf_config=convert_gpt_hf_config, - convert_mcore2hf=convert_mcore2hf, - convert_hf2mcore=convert_hf2mcore, )) diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py deleted file mode 100644 index cc8960163b..0000000000 --- a/swift/megatron/model/gpt/hf2mcore.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Optional - -import torch -from megatron.training import get_args -from torch import nn - - -def set_mla_attn_state(args, mg_attn, hf_attn): - mg_attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight) - if args.q_lora_rank is None: - mg_attn.linear_q_proj.weight.data.copy_(hf_attn.q_proj.weight) - else: - mg_attn.linear_q_down_proj.weight.data.copy_(hf_attn.q_a_proj.weight) - mg_attn.linear_q_up_proj.weight.data.copy_(hf_attn.q_b_proj.weight) - mg_attn.linear_kv_down_proj.weight.data.copy_(hf_attn.kv_a_proj_with_mqa.weight) - mg_attn.linear_kv_up_proj.weight.data.copy_(hf_attn.kv_b_proj.weight) - if args.qk_layernorm: - mg_attn.linear_kv_up_proj.layer_norm_weight.data.copy_(hf_attn.kv_a_layernorm.weight) - - -def set_attn_state(args, mg_attn, hf_attn): - num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - - # Copy weights - mg_attn.linear_qkv.weight.data.copy_( - torch.cat([ - hf_attn.q_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), - hf_attn.k_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), - hf_attn.v_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), - ], - dim=1).reshape((-1, args.hidden_size))) - mg_attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight) - - # Copy bias - if args.add_qkv_bias: - mg_attn.linear_qkv.bias.data.copy_( - torch.cat([ - hf_attn.q_proj.bias.reshape((num_query_groups, -1)), - hf_attn.k_proj.bias.reshape((num_query_groups, -1)), - hf_attn.v_proj.bias.reshape((num_query_groups, -1)), - ], - dim=1).reshape(-1)) - if args.qk_layernorm: - q_norm = hf_attn.query_layernorm if hasattr(hf_attn, 'query_layernorm') else hf_attn.q_norm - k_norm = hf_attn.key_layernorm if hasattr(hf_attn, 'key_layernorm') else hf_attn.k_norm - mg_attn.q_layernorm.weight.data.copy_(q_norm.weight) - mg_attn.k_layernorm.weight.data.copy_(k_norm.weight) - - -def _set_mlp_state(mg_mlp, hf_mlp, group_idx: Optional[int] = None): - hf_grouped = not isinstance(hf_mlp.down_proj, nn.Module) - if group_idx is None: - linear_fc1_weight = mg_mlp.linear_fc1.weight - linear_fc2_weight = mg_mlp.linear_fc2.weight - else: - linear_fc1_weight = 
getattr(mg_mlp.linear_fc1, f'weight{group_idx}') - linear_fc2_weight = getattr(mg_mlp.linear_fc2, f'weight{group_idx}') - if hf_grouped: - linear_fc1_weight.data.copy_(hf_mlp.gate_up_proj[group_idx].t()) - linear_fc2_weight.data.copy_(hf_mlp.down_proj[group_idx].t()) - else: - if hasattr(hf_mlp, 'gate_up_proj'): - linear_fc1_weight.data.copy_(hf_mlp.gate_up_proj.weight) - else: - linear_fc1_weight.data.copy_(torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0)) - linear_fc2_weight.data.copy_(hf_mlp.down_proj.weight) - - -def _set_moe_state(args, mg_mlp, hf_mlp): - hf_gate = hf_mlp.gate - if hasattr(hf_gate, 'wg'): - hf_gate = hf_gate.wg - mg_mlp.router.weight.data.copy_(hf_gate.weight) - if args.moe_router_enable_expert_bias: - mg_mlp.router.expert_bias.data.copy_(hf_gate.e_score_correction_bias) - if mg_mlp.shared_experts is not None: - if hasattr(hf_mlp, 'shared_experts'): - hf_shared_expert = hf_mlp.shared_experts - elif hasattr(hf_mlp, 'shared_mlp'): - hf_shared_expert = hf_mlp.shared_mlp - else: - hf_shared_expert = hf_mlp.shared_expert - _set_mlp_state(mg_mlp.shared_experts, hf_shared_expert) - if mg_mlp.shared_experts.gate_weight is not None: - mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight) - for expert_idx in range(args.num_experts): - hf_expert = hf_mlp.experts - if hasattr(hf_expert, '__len__'): - hf_expert = hf_expert[expert_idx] - _set_mlp_state(mg_mlp.experts, hf_expert, group_idx=expert_idx) - - -def set_mlp_state(args, mg_mlp, hf_mlp): - if 'moe' in mg_mlp.__class__.__name__.lower(): - _set_moe_state(args, mg_mlp, hf_mlp) - else: - _set_mlp_state(mg_mlp, hf_mlp) - - -def set_layer_state(args, mg_model, hf_model, layer_idx): - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.layers[layer_idx] - if args.multi_latent_attention: - set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - mg_layer.input_layernorm.weight.data.copy_(hf_layer.input_layernorm.weight) - else: - set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - mg_layer.self_attention.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - - post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight - if 'moe' in mg_layer.mlp.__class__.__name__.lower(): - mg_layer.pre_mlp_layernorm.weight.data.copy_(post_attention_layernorm_weight) - else: - mg_layer.mlp.linear_fc1.layer_norm_weight.data.copy_(post_attention_layernorm_weight) - - -def convert_hf2mcore(hf_model, mg_model): - args = get_args() - mg_model.embedding.word_embeddings.weight.data.copy_(hf_model.model.embed_tokens.weight) - if args.untie_embeddings_and_output_weights: - mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight) - for layer_idx in range(args.num_layers): - set_layer_state(args, mg_model, hf_model.model, layer_idx) diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py deleted file mode 100644 index eac8023801..0000000000 --- a/swift/megatron/model/gpt/mcore2hf.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-from typing import Optional - -from megatron.training import get_args -from torch import nn - - -def set_mla_attn_state(args, mg_attn, hf_attn): - hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight) - if args.q_lora_rank is None: - hf_attn.q_proj.weight.data.copy_(mg_attn.linear_q_proj.weight) - else: - hf_attn.q_a_proj.weight.data.copy_(mg_attn.linear_q_down_proj.weight) - hf_attn.q_b_proj.weight.data.copy_(mg_attn.linear_q_up_proj.weight) - hf_attn.kv_a_proj_with_mqa.weight.data.copy_(mg_attn.linear_kv_down_proj.weight) - hf_attn.kv_b_proj.weight.data.copy_(mg_attn.linear_kv_up_proj.weight) - if args.qk_layernorm: - hf_attn.kv_a_layernorm.weight.data.copy_(mg_attn.linear_kv_up_proj.layer_norm_weight) - - -def set_attn_state(args, mg_attn, hf_attn): - num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - # Copy weights - mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size)) - q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ - 0] // num_query_groups - hf_attn.q_proj.weight.data.copy_(mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size)) - hf_attn.k_proj.weight.data.copy_(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size)) - hf_attn.v_proj.weight.data.copy_(mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size)) - hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight) - - # Copy bias - if args.add_qkv_bias: - mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1)) - hf_attn.q_proj.bias.data.copy_(mg_attn_bias[:, :q_dim].reshape(-1)) - hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1)) - hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1)) - - if args.qk_layernorm: - q_norm = hf_attn.query_layernorm if hasattr(hf_attn, 'query_layernorm') else hf_attn.q_norm - k_norm = hf_attn.key_layernorm if hasattr(hf_attn, 'key_layernorm') else hf_attn.k_norm - q_norm.weight.data.copy_(mg_attn.q_layernorm.weight) - k_norm.weight.data.copy_(mg_attn.k_layernorm.weight) - - -def _set_moe_state(args, mg_mlp, hf_mlp): - hf_gate = hf_mlp.gate - if hasattr(hf_gate, 'wg'): - hf_gate = hf_gate.wg - hf_gate.weight.data.copy_(mg_mlp.router.weight) - if args.moe_router_enable_expert_bias: - hf_gate.e_score_correction_bias.data.copy_(mg_mlp.router.expert_bias) - if mg_mlp.shared_experts is not None: - if hasattr(hf_mlp, 'shared_experts'): - hf_shared_expert = hf_mlp.shared_experts - elif hasattr(hf_mlp, 'shared_mlp'): - hf_shared_expert = hf_mlp.shared_mlp - else: - hf_shared_expert = hf_mlp.shared_expert - _set_mlp_state(mg_mlp.shared_experts, hf_shared_expert) - if mg_mlp.shared_experts.gate_weight is not None: - hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight) - for expert_idx in range(args.num_experts): - hf_expert = hf_mlp.experts - if hasattr(hf_expert, '__len__'): - hf_expert = hf_expert[expert_idx] - _set_mlp_state(mg_mlp.experts, hf_expert, group_idx=expert_idx) - - -def _set_mlp_state(mg_mlp, hf_mlp, group_idx: Optional[int] = None): - hf_grouped = not isinstance(hf_mlp.down_proj, nn.Module) - if group_idx is None: - linear_fc1_weight = mg_mlp.linear_fc1.weight - linear_fc2_weight = mg_mlp.linear_fc2.weight - else: - linear_fc1_weight = getattr(mg_mlp.linear_fc1, f'weight{group_idx}') - linear_fc2_weight = getattr(mg_mlp.linear_fc2, f'weight{group_idx}') - - if hf_grouped: - hf_mlp.gate_up_proj.data[group_idx] = linear_fc1_weight.t() - 
hf_mlp.down_proj.data[group_idx] = linear_fc2_weight.t() - else: - if hasattr(hf_mlp, 'gate_up_proj'): - hf_mlp.gate_up_proj.weight.data.copy_(linear_fc1_weight) - else: - ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0] - hf_mlp.gate_proj.weight.data.copy_(linear_fc1_weight[:ffn_hidden_size]) - hf_mlp.up_proj.weight.data.copy_(linear_fc1_weight[ffn_hidden_size:]) - hf_mlp.down_proj.weight.data.copy_(linear_fc2_weight) - - -def set_mlp_state(args, mg_mlp, hf_mlp): - if 'moe' in mg_mlp.__class__.__name__.lower(): - _set_moe_state(args, mg_mlp, hf_mlp) - else: - _set_mlp_state(mg_mlp, hf_mlp) - - -def set_layer_state(args, mg_model, hf_model, layer_idx): - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.layers[layer_idx] - - if args.multi_latent_attention: - set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - hf_layer.input_layernorm.weight.data.copy_(mg_layer.input_layernorm.weight) - else: - set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - hf_layer.input_layernorm.weight.data.copy_(mg_layer.self_attention.linear_qkv.layer_norm_weight) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - - post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight - if 'moe' in mg_layer.mlp.__class__.__name__.lower(): - post_attention_layernorm_weight.data.copy_(mg_layer.pre_mlp_layernorm.weight) - else: - post_attention_layernorm_weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight) - - -def convert_mcore2hf(hf_model, mg_model): - args = get_args() - hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) - if args.untie_embeddings_and_output_weights: - lm_head_weight = hf_model.score.weight if args.task_type == 'seq_cls' else hf_model.lm_head.weight - lm_head_weight.data.copy_(mg_model.output_layer.weight) - hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight) - for layer_idx in range(args.num_layers): - set_layer_state(args, mg_model, hf_model.model, layer_idx) diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index 1d291b6a4c..2dd9026dd9 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -3,7 +3,6 @@ from typing import Optional, Tuple, Union import torch -from megatron.core import mpu from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb @@ -22,9 +21,8 @@ from swift.llm import ModelType from swift.utils import get_logger from ..constant import MegatronModelType -from ..gpt_model import GPTModel +from ..gpt_bridge import GPTBridge from ..register import MegatronModelMeta, register_megatron_model -from .config import convert_gpt_hf_config try: from flashattn_hopper.flash_attn_interface import _flash_attn_forward @@ -473,61 +471,26 @@ def get_qwen3_next_transformer_layer_spec(config, vp_stage=None): return block_spec -def convert_mcore2hf_qwen3_next(hf_model, mg_model): - from .mcore2hf import set_mlp_state, set_attn_state - args = get_args() - hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) - if args.untie_embeddings_and_output_weights: - lm_head_weight = hf_model.score.weight if args.task_type == 'seq_cls' else hf_model.lm_head.weight - lm_head_weight.data.copy_(mg_model.output_layer.weight) - 
hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight - 1) - for layer_idx in range(args.num_layers): - layer_type = args.layer_types[layer_idx] - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.model.layers[layer_idx] - mg_attn = mg_layer.self_attention - - if layer_type == 'linear_attention': - hf_layer.linear_attn.load_state_dict(mg_attn.state_dict(), strict=False) - hf_layer.input_layernorm.weight.data.copy_(mg_layer.input_layernorm.weight - 1) - elif layer_type == 'full_attention': - hf_attn = hf_layer.self_attn - set_attn_state(args, mg_attn, hf_attn) - hf_layer.input_layernorm.weight.data.copy_(mg_attn.linear_qkv.layer_norm_weight - 1) - if args.qk_layernorm: - hf_attn.q_norm.weight.data.copy_(mg_attn.q_layernorm.weight - 1) - hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight - 1) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - hf_layer.post_attention_layernorm.weight.data.copy_(mg_layer.pre_mlp_layernorm.weight - 1) +class Qwen3NextBridge(GPTBridge): + def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool, offset: float = 0): + if hf_key in { + 'model.norm.weight', 'q_norm.weight', 'k_norm.weight', 'input_layernorm.weight', + 'post_attention_layernorm.weight' + }: + offset = -1 if reverse else 1 + else: + assert 'norm' not in hf_key, f'hf_key: {hf_key}' # just check + return super()._set_state_dict(state_dict, res_state_dict, hf_key, mg_key, reverse, offset) -def convert_hf2mcore_qwen3_next(hf_model, mg_model): - from .hf2mcore import set_mlp_state, set_attn_state - args = get_args() - mg_model.embedding.word_embeddings.weight.data.copy_(hf_model.model.embed_tokens.weight) - if args.untie_embeddings_and_output_weights: - mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight + 1) - for layer_idx in range(args.num_layers): - layer_type = args.layer_types[layer_idx] - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.model.layers[layer_idx] - mg_attn = mg_layer.self_attention - + def _set_layer_attn(self, state_dict, layer_idx: int, reverse: bool): + layer_type = self.args.layer_types[layer_idx] if layer_type == 'linear_attention': - mg_attn.load_state_dict(hf_layer.linear_attn.state_dict(), strict=False) - mg_layer.input_layernorm.weight.data.copy_(hf_layer.input_layernorm.weight + 1) + res = self._replace_prefix(state_dict, 'linear_attn.', 'self_attention.', reverse) + self._set_state_dict(state_dict, res, 'input_layernorm.weight', 'input_layernorm.weight', reverse) elif layer_type == 'full_attention': - hf_attn = hf_layer.self_attn - set_attn_state(args, mg_attn, hf_attn) - mg_attn.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight + 1) - if args.qk_layernorm: - mg_attn.q_layernorm.weight.data.copy_(hf_attn.q_norm.weight + 1) - mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight + 1) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - mg_layer.pre_mlp_layernorm.weight.data.copy_(hf_layer.post_attention_layernorm.weight + 1) + res = super()._set_layer_attn(state_dict, layer_idx, reverse) + return res register_megatron_model( @@ -537,9 +500,6 @@ def convert_hf2mcore_qwen3_next(hf_model, mg_model): ModelType.qwen3_next, ModelType.qwen3_next_thinking, ], - model_cls=GPTModel, - convert_hf_config=convert_gpt_hf_config, get_transformer_layer_spec=get_qwen3_next_transformer_layer_spec, - convert_mcore2hf=convert_mcore2hf_qwen3_next, - 
convert_hf2mcore=convert_hf2mcore_qwen3_next, + bridge_cls=Qwen3NextBridge, )) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py new file mode 100644 index 0000000000..f0a869cde9 --- /dev/null +++ b/swift/megatron/model/gpt_bridge.py @@ -0,0 +1,284 @@ +from typing import Dict, Optional + +import torch +from megatron.training import get_args +from tqdm import tqdm + +from swift.llm import deep_getattr, get_model_tokenizer +from swift.utils import disable_safe_ddp_context_use_barrier + + +class GPTBridge: + lm_layers_prefix = 'model.layers' # HF model + + def __init__(self): + self.args = get_args() + model_info = self.args.model_info + with torch.device('meta'), disable_safe_ddp_context_use_barrier(): + self.hf_model, _ = get_model_tokenizer( + model_info.model_dir, model_type=model_info.model_type, return_dummy_model=True) + self.hf_layers = deep_getattr(self.hf_model, self.lm_layers_prefix) + + def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool, offset: float = 0): + src_key, tgt_key = hf_key, mg_key + if reverse: + src_key, tgt_key = tgt_key, src_key + res_state_dict[tgt_key] = state_dict[src_key] + if offset: + res_state_dict[tgt_key] = res_state_dict[tgt_key] + offset + + @staticmethod + def _remove_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} + + @staticmethod + def _add_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {f'{prefix}{k}': v for k, v in state_dict.items()} + + @staticmethod + def _filter_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {k: v for k, v in state_dict.items() if k.startswith(prefix)} + + @staticmethod + def _replace_prefix(state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + res = GPTBridge._remove_prefix(state_dict, src_prefix) + return GPTBridge._add_prefix(res, tgt_prefix) + + @staticmethod + def _is_moe(state_dict): + for k, v in state_dict.items(): + if 'experts.' 
in k: + return True + return False + + def _set_attn_state(self, state_dict, hf_prefix: str, mg_prefix: str, layer_idx: int, reverse: bool): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + hf_attn = self.hf_layers[layer_idx].self_attn + args = self.args + res = {} + num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) + if reverse: + mg_attn_weight = state_dict['linear_qkv.weight'].reshape((num_query_groups, -1, args.hidden_size)) + q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ + 0] // num_query_groups + res['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size) + res['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size) + res['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size) + else: + res['linear_qkv.weight'] = torch.cat([ + state_dict['q_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + state_dict['k_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + state_dict['v_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + ], + dim=1).reshape((-1, args.hidden_size)) + self._set_state_dict(state_dict, res, 'o_proj.weight', 'linear_proj.weight', reverse) + + # Copy bias + if args.add_qkv_bias: + if reverse: + mg_attn_bias = state_dict['linear_qkv.bias'].reshape((num_query_groups, -1)) + res['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1) + res['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1) + res['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1) + else: + res['linear_qkv.bias'] = torch.cat([ + state_dict['q_proj.bias'].reshape((num_query_groups, -1)), + state_dict['k_proj.bias'].reshape((num_query_groups, -1)), + state_dict['v_proj.bias'].reshape((num_query_groups, -1)), + ], + dim=1).reshape(-1) + if args.qk_layernorm: + hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight' + hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight' + self._set_state_dict(state_dict, res, hf_q_norm_key, 'q_layernorm.weight', reverse) + self._set_state_dict(state_dict, res, hf_k_norm_key, 'k_layernorm.weight', reverse) + + return self._add_prefix(res, tgt_prefix) + + def _set_moe_state(self, state_dict, hf_prefix: str, mg_prefix: str, layer_idx: int, reverse: bool): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + hf_mlp = self.hf_layers[layer_idx].mlp + res = {} + hf_gate_key = 'gate.wg.weight' if hasattr(hf_mlp.gate, 'wg') else 'gate.weight' + self._set_state_dict(state_dict, res, hf_gate_key, 'router.weight', reverse) + if self.args.moe_router_enable_expert_bias: + self._set_state_dict(state_dict, res, 'gate.e_score_correction_bias', 'router.expert_bias', reverse) + + if self.args.moe_shared_expert_intermediate_size: + for key in ['shared_expert', 'shared_experts', 'shared_mlp']: + if hasattr(hf_mlp, key): + hf_shared_expert_prefix = f'{key}.' 
+ res.update(self._set_mlp_state(state_dict, hf_shared_expert_prefix, 'shared_experts.', layer_idx, reverse)) + if hasattr(hf_mlp, 'shared_expert_gate'): + self._set_state_dict(state_dict, res, 'shared_expert_gate.weight', 'shared_experts.gate_weight', + reverse) + for expert_idx in range(self.args.num_experts): + hf_expert_prefix = f'experts.{expert_idx}.' if hasattr(hf_mlp.experts, '__len__') else 'experts.' + res.update( + self._set_mlp_state(state_dict, hf_expert_prefix, 'experts.', layer_idx, reverse, group_idx=expert_idx)) + return self._add_prefix(res, tgt_prefix) + + def _set_mlp_state( + self, + state_dict, + hf_prefix: str, + mg_prefix: str, + layer_idx: int, + reverse: bool, + group_idx: Optional[int] = None, + ): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + hf_mlp = self.hf_layers[layer_idx].mlp + hf_grouped = False + if group_idx is not None and not hasattr(hf_mlp.experts, '__len__'): + hf_grouped = True + res = {} + # Determines the keys for fc1 and fc2 in megatron + if group_idx is None: + fc1_key = 'linear_fc1.weight' + fc2_key = 'linear_fc2.weight' + else: + fc1_key = f'linear_fc1.weight{group_idx}' + fc2_key = f'linear_fc2.weight{group_idx}' + if hf_grouped: + res[fc1_key] = state_dict['gate_up_proj'][group_idx].t() + res[fc2_key] = state_dict['down_proj'][group_idx].t() + else: + if hasattr(hf_mlp, 'gate_up_proj'): + self._set_state_dict(state_dict, res, 'gate_up_proj.weight', fc1_key, reverse) + else: + if reverse: + ffn_hidden_size = state_dict[fc1_key].shape[0] // 2 + res['gate_proj.weight'] = state_dict[fc1_key][:ffn_hidden_size] + res['up_proj.weight'] = state_dict[fc1_key][ffn_hidden_size:] + else: + res[fc1_key] = torch.cat([ + state_dict['gate_proj.weight'], + state_dict['up_proj.weight'], + ], dim=0) + self._set_state_dict(state_dict, res, 'down_proj.weight', fc2_key, reverse) + return self._add_prefix(res, tgt_prefix) + + def _set_mla_attn_state( + self, + state_dict, + hf_prefix: str, + mg_prefix: str, + layer_idx: int, + reverse: bool, + ): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + res = {} + self._set_state_dict(state_dict, res, 'o_proj.weight', 'linear_proj.weight', reverse) + if self.args.q_lora_rank is None: + self._set_state_dict(state_dict, res, 'q_proj.weight', 'linear_q_proj.weight', reverse) + else: + self._set_state_dict(state_dict, res, 'q_a_proj.weight', 'linear_q_down_proj.weight', reverse) + self._set_state_dict(state_dict, res, 'q_b_proj.weight', 'linear_q_up_proj.weight', reverse) + self._set_state_dict(state_dict, res, 'kv_a_proj_with_mqa.weight', 'linear_kv_down_proj.weight', reverse) + self._set_state_dict(state_dict, res, 'kv_b_proj.weight', 'linear_kv_up_proj.weight', reverse) + if self.args.qk_layernorm: + self._set_state_dict(state_dict, res, 'kv_a_layernorm.weight', 'linear_kv_up_proj.layer_norm_weight', + reverse) + return self._add_prefix(res, tgt_prefix) + + def _set_layer_attn(self, state_dict, layer_idx: int, reverse: bool): + res = {} + if self.args.multi_latent_attention: + res.update(self._set_mla_attn_state(state_dict, 'self_attn.', 'self_attention.', layer_idx, reverse)) + self._set_state_dict(state_dict, res, 'input_layernorm.weight', 'input_layernorm.weight', reverse) + else: + res.update(self._set_attn_state(state_dict, 'self_attn.', 'self_attention.', layer_idx, 
reverse)) + self._set_state_dict(state_dict, res, 'input_layernorm.weight', + 'self_attention.linear_qkv.layer_norm_weight', reverse) + return res + + def _set_layer_mlp(self, state_dict, layer_idx: int, reverse: bool): + hf_mlp = self.hf_layers[layer_idx].mlp + res = {} + is_moe = self._is_moe(hf_mlp.state_dict()) + if is_moe: + res.update(self._set_moe_state(state_dict, 'mlp.', 'mlp.', layer_idx, reverse)) + self._set_state_dict(state_dict, res, 'post_attention_layernorm.weight', 'pre_mlp_layernorm.weight', + reverse) + else: + res.update(self._set_mlp_state(state_dict, 'mlp.', 'mlp.', layer_idx, reverse)) + self._set_state_dict(state_dict, res, 'post_attention_layernorm.weight', 'mlp.linear_fc1.layer_norm_weight', + reverse) + return res + + def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix: str, reverse: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' + mg_prefix = f'{mg_prefix}{layer_idx}.' + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + res = self._set_layer_attn(state_dict, layer_idx, reverse) + res.update(self._set_layer_mlp(state_dict, layer_idx, reverse)) + return self._add_prefix(res, tgt_prefix) + + def _convert(self, state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): + """reverse: False: hf -> mg; True: mg -> hf""" + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + res = {} + self._set_state_dict(state_dict, res, 'model.embed_tokens.weight', 'embedding.word_embeddings.weight', reverse) + if self.args.untie_embeddings_and_output_weights: + hf_lm_head_key = 'lm_head.weight' + if reverse and self.args.task_type == 'seq_cls': + hf_lm_head_key = 'score.weight' + self._set_state_dict(state_dict, res, hf_lm_head_key, 'output_layer.weight', reverse) + self._set_state_dict(state_dict, res, 'model.norm.weight', 'decoder.final_layernorm.weight', reverse) + for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): + res.update(self._set_layer_state(state_dict, layer_idx, 'model.layers.', 'decoder.layers.', reverse)) + return self._add_prefix(res, tgt_prefix) + + def convert_hf2mcore(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return self._convert(state_dict, '', '', False) + + def convert_mcore2hf(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return self._convert(state_dict, '', '', True) + + def load_state_dict(self, model, state_dict) -> None: + """The model can be either hf_model or mg_model""" + incompatible_keys = model.load_state_dict(state_dict, strict=False) + missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] + assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' + assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' + + def load_from_hf_checkpoint(self, mg_model, hf_model_dir: str) -> None: + """Load the parameters required by the mg_model structure from the HF checkpoint and scatter them across ranks.""" + print() + + def get_hf_state_dict(self, mg_models) -> Dict[str, torch.Tensor]: + """Gather the complete HF-format state_dict.""" + print() + + def save_hf_checkpoint(self, mg_models, output_dir: str) -> None: + """Save the mg_model as an HF-format checkpoint.""" + state_dict = self.get_hf_state_dict(mg_models) + # rank0: save state_dict to output_dir diff --git a/swift/megatron/model/mm_gpt/qwen3_vl.py 
b/swift/megatron/model/mm_gpt/qwen3_vl.py index de222f75d7..b0ce9b8f02 100644 --- a/swift/megatron/model/mm_gpt/qwen3_vl.py +++ b/swift/megatron/model/mm_gpt/qwen3_vl.py @@ -15,7 +15,7 @@ from swift.llm import ModelType, to_device from ..constant import MegatronModelType -from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore +from ..gpt.hf2mcore import _add_prefix, _remove_prefix, convert_hf2mcore from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf from ..mm_gpt_model import MultimodalGPTModel from ..register import register_megatron_model @@ -500,17 +500,14 @@ def __init__(self, *args, **kwargs): visual_cls=Qwen3Omni_Vit)) -def convert_hf2mcore_qwen3_vl(hf_model, mg_model): - language_model = hf_model.model.language_model - mg_language_model = mg_model.language_model +def convert_hf2mcore_qwen3_vl(state_dict, prefix=''): args = get_args() - mg_language_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight) + mg_state_dict = {} if args.untie_embeddings_and_output_weights: - mg_language_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_language_model.decoder.final_layernorm.weight.data.copy_(language_model.norm.weight) - for layer_idx in range(args.num_layers): - set_layer_state_hf2mcore(args, mg_language_model, language_model, layer_idx) - mg_model.visual.visual.load_state_dict(hf_model.model.visual.state_dict()) + mg_state_dict['language_model.output_layer.weight'] = state_dict['lm_head.weight'] + mg_state_dict.update(convert_hf2mcore(state_dict, 'language_model.')) + mg_state_dict.update(_add_prefix(_remove_prefix(state_dict, 'model.visual.'), 'visual.visual.')) + return _add_prefix(mg_state_dict, prefix) def convert_mcore2hf_qwen3_vl(hf_model, mg_model): diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 814cd82213..392785e085 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -14,14 +14,12 @@ if TYPE_CHECKING: from .gpt_model import GPTModel - from .mm_gpt_model import MultimodalGPTModel # Code borrowed from NVIDIA/Megatron-LM -def model_provider( - pre_process=True, - post_process=True, - vp_stage: Optional[int] = None) -> Union['GPTModel', 'MultimodalGPTModel', megatron.legacy.model.GPTModel]: +def model_provider(pre_process=True, + post_process=True, + vp_stage: Optional[int] = None) -> Union['GPTModel', megatron.legacy.model.GPTModel]: """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 93f892c2e8..877a09db92 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -1,12 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from argparse import ArgumentParser from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Callable, List, Optional, Type import torch.nn as nn -from transformers import PretrainedConfig from swift.llm import MODEL_MAPPING +from .constant import MLLMMegatronModelType +from .gpt_bridge import GPTBridge +from .gpt_model import GPTModel +from .mm_gpt_model import MultimodalGPTModel from .model_provider import model_provider as model_provider_func MEGATRON_MODEL_MAPPING = {} @@ -17,17 +20,18 @@ class MegatronModelMeta: megatron_model_type: str model_types: List[str] - convert_mcore2hf: Callable[[nn.Module, nn.Module], None] - convert_hf2mcore: Callable[[nn.Module, nn.Module], None] - - model_cls: Type[nn.Module] - convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] + is_multimodal: bool = False + bridge_cls: Type[GPTBridge] = GPTBridge get_transformer_layer_spec: Optional[Callable] = None model_provider: Callable[[], nn.Module] = model_provider_func visual_cls: Optional[Type[nn.Module]] = None extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None + @property + def model_cls(self): + return MultimodalGPTModel if self.is_multimodal else GPTModel + def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): megatron_model_type = megatron_model_meta.megatron_model_type @@ -36,7 +40,8 @@ def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: model_meta.support_megatron = True if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING: raise ValueError(f'The `{megatron_model_type}` has already been registered in the MODEL_MAPPING.') - + if megatron_model_type in MLLMMegatronModelType.__dict__: + megatron_model_meta.is_multimodal = True MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 15b81434f9..d71eebdc68 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -10,7 +10,6 @@ from swift.utils import get_logger, is_master, plot_images from ..argument import MegatronTrainArguments from ..trainers import MegatronTrainer -from ..utils import patch_megatron_tokenizer from .utils import build_streaming_dataloader logger = get_logger() @@ -35,7 +34,6 @@ def __init__(self, args: Optional[Union[List[str], MegatronTrainArguments]] = No with torch.device('meta'): self.model, self.processor = args.get_model_processor(**kwargs) self._prepare_template() - patch_megatron_tokenizer(self.processor) args.init_model_args(self.tokenizer, self.processor.model_info.config) args.save_args(args.save) self.template.use_megatron = True diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 52a6f7077f..5ea4613846 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -47,6 +47,7 @@ def __init__(self, args, template): self.stimer = StragglerDetector() self.unwrapped_models = [] self.peft_models = [] + self.bridge = args.megatron_model_meta.bridge_cls() logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') self.jsonl_writer = JsonlWriter(logging_path, enable_async=True, write_on_rank='last') # for evaluate @@ -251,13 +252,16 @@ def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs): def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): - def new_model_provider_func(*args, **kwargs): - 
model = model_provider_func(*args, **kwargs) + args = get_args() + + def new_model_provider_func(*_args, **kwargs): + model = model_provider_func(*_args, **kwargs) + if args.load_hf_checkpoint: + self.bridge.load_from_hf_checkpoint(model, args.model_info.model_dir) self.unwrapped_models.append(model) self.peft_models.append(prepare_mcore_model(model)) return model - args = get_args() self._init_multimodal_full(args) with self._patch_load_state_dict(self._load_base_checkpoint): model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer( @@ -724,9 +728,14 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear return report_memory_flag - def save_checkpoint(self, *args, **kwargs): - with adapter_state_dict_context(): - return self._origin_save_checkpoint(*args, **kwargs) + def save_checkpoint(self, iteration, *_args, **kwargs): + args = get_args() + if args.save_hf_checkpoint: + output_dir = os.path.join(args.save, f'checkpoint-{iteration}') + self.bridge.save_hf_checkpoint(self.unwrapped_models, output_dir) + else: + with adapter_state_dict_context(): + return self._origin_save_checkpoint(iteration, *_args, **kwargs) def _patch_megatron(self): # support max_epochs diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 1afd505df0..91fbd0fbf8 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .convert import convert_hf2mcore, convert_mcore2hf -from .patcher import patch_megatron_tokenizer +from .config import convert_hf_config +from .patcher import patch_torch_dist_shard from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/utils/config.py similarity index 50% rename from swift/megatron/model/gpt/config.py rename to swift/megatron/utils/config.py index e779a827b6..fba9dee26c 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/utils/config.py @@ -1,11 +1,92 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Any, Dict -from ..config import convert_hf_config +from swift.utils import get_logger +logger = get_logger() +config_mapping = { + 'num_layers': ['num_hidden_layers'], + 'hidden_size': ['hidden_size'], + 'ffn_hidden_size': ['intermediate_size'], + 'num_attention_heads': ['num_attention_heads'], + 'num_query_groups': ['num_key_value_heads'], + 'max_position_embeddings': ['max_position_embeddings'], + 'norm_epsilon': ['rms_norm_eps'], + 'rotary_base': ['rope_theta'], + 'padded_vocab_size': ['vocab_size'], + 'attention_dropout': ['attention_dropout'], + 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], + 'swiglu': ['hidden_act'], + 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], + 'disable_bias_linear': ['mlp_bias'], + 'kv_channels': ['head_dim', 'v_head_dim'], + 'architectures': ['architectures'], + # moe + 'moe_ffn_hidden_size': ['moe_intermediate_size'], + 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], + 'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk', 'moe_k'], + 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts'], + 'moe_router_pre_softmax': ['norm_topk_prob'], + # deepseek + 'q_lora_rank': ['q_lora_rank'], + 'kv_lora_rank': ['kv_lora_rank'], + 'moe_router_score_function': ['scoring_func'], + 'qk_head_dim': ['qk_nope_head_dim'], + 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], + 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], + 'qk_layernorm': ['use_qk_norm'], + # qwen3_next + 'linear_num_value_heads': ['linear_num_value_heads'], + 'linear_num_key_heads': ['linear_num_key_heads'], + 'linear_key_head_dim': ['linear_key_head_dim'], + 'linear_value_head_dim': ['linear_value_head_dim'], + 'linear_conv_kernel_dim': ['linear_conv_kernel_dim'], + 'full_attention_interval': ['full_attention_interval'], + # other + 'original_max_position_embeddings': ['original_max_position_embeddings'], + 'partial_rotary_factor': ['partial_rotary_factor'], + 'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'], + 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'], +} -def convert_gpt_hf_config(config) -> Dict[str, Any]: - res = convert_hf_config(config) + +def _convert_config(config, _internal_call=False) -> Dict[str, Any]: + megatron_config = {} + for k, hf_keys in config_mapping.items(): + for hf_k in hf_keys: + if hasattr(config, hf_k): + hf_v = getattr(config, hf_k) + if hf_v is None: + continue + if k == 'rotary_base': + megatron_config[k] = int(hf_v) + elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: + megatron_config[k] = not hf_v + elif k == 'swiglu': + if hf_v == 'silu': + megatron_config[k] = True + else: + if k == 'kv_lora_rank': + megatron_config['multi_latent_attention'] = True + elif k == 'architectures': + if _internal_call: + k = 'llm_architectures' + megatron_config[k] = hf_v + break + for key in ['text_config', 'llm_config', 'thinker_config']: + if hasattr(config, key): + megatron_config.update(_convert_config(getattr(config, key), _internal_call=True)) + # compat llama3 + if getattr(config, 'rope_scaling', None) is not None: + if isinstance(config.rope_scaling, int): + megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'} + elif isinstance(config.rope_scaling, dict): + megatron_config['rope_scaling'] = config.rope_scaling + return megatron_config + + +def convert_hf_config(config) -> Dict[str, Any]: + res = _convert_config(config) 
architectures = res.get('architectures') if isinstance(architectures, list) and architectures: architectures = architectures[0] diff --git a/swift/megatron/utils/lazy_tensor.py b/swift/megatron/utils/lazy_tensor.py new file mode 100644 index 0000000000..f594b26dbf --- /dev/null +++ b/swift/megatron/utils/lazy_tensor.py @@ -0,0 +1,74 @@ +import os +from functools import partial + +import json +import safetensors.torch + + +class LazyTensor: + + def __init__(self, tensor=None, loader=None): + """You need to provide a tensor or loader""" + self.tensor = tensor + self.loader = loader + + def load(self): + if self.tensor is None: + self.tensor = self.loader() + self.loader = None + return self.tensor + + +class SafetensorsLazyLoader: + + def __init__(self, hf_model_dir: str): + self.hf_model_dir = hf_model_dir + self._weight_map = {} + self._file_handles = {} + self._load_index() + + def _open_file(self, filename: str): + """Open a safetensors file if not already open.""" + if filename not in self._file_handles: + file_path = os.path.join(self.hf_model_dir, filename) + self._file_handles[filename] = safetensors.torch.safe_open(file_path, framework='pt') + return self._file_handles[filename] + + def _load_index(self): + """Load the model index file to get weight map.""" + index_path = os.path.join(self.hf_model_dir, 'model.safetensors.index.json') + + if os.path.exists(index_path): + with open(index_path, 'r') as f: + self._index_file = json.load(f) + self._weight_map = self._index_file.get('weight_map', {}) + else: + # Single file model + safetensors_file = os.path.join(self.hf_model_dir, 'model.safetensors') + if os.path.exists(safetensors_file): + # All weights are in single file + with safetensors.torch.safe_open(safetensors_file, framework='pt') as f: + for key in f.keys(): + self._weight_map[key] = 'model.safetensors' + + def get_state_dict(self): + res = {} + for k in self._weight_map.keys(): + res[k] = LazyTensor(loader=partial(self._load_tensor, key=k)) + return res + + def _load_tensor(self, key): + filename = self._weight_map[key] + file_handle = self._open_file(filename) + return file_handle.get_tensor(key) + + def close(self): + for f in self._file_handles.values(): + f.close() + self._file_handles.clear() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() diff --git a/swift/megatron/utils/patcher.py b/swift/megatron/utils/patcher.py index 49dec85dea..9fd1cbf169 100644 --- a/swift/megatron/utils/patcher.py +++ b/swift/megatron/utils/patcher.py @@ -7,14 +7,6 @@ logger = get_logger() -def patch_megatron_tokenizer(tokenizer): - - def build_tokenizer(args): - return tokenizer - - global_vars.build_tokenizer = build_tokenizer - - def patch_torch_dist_shard(thread_count): __init__ = TorchDistSaveShardedStrategy.__init__ diff --git a/tests/megatron/test_save.py b/tests/megatron/test_save.py index cfc78182ae..c19b7792e6 100644 --- a/tests/megatron/test_save.py +++ b/tests/megatron/test_save.py @@ -12,7 +12,7 @@ def get_mg_model_tokenizer(): _, processor = get_model_tokenizer(model_id, load_model=False) megatron_model_meta = get_megatron_model_meta(processor.model_meta.model_type) model_info = processor.model_info - kwargs = megatron_model_meta.convert_hf_config(model_info.config) + kwargs = convert_hf_config(model_info.config) megatron_args = MegatronArguments( **kwargs, seq_length=1, @@ -22,7 +22,6 @@ save='mcore-hf-test', no_load_optim=True, no_load_rng=True) - patch_megatron_tokenizer(processor) extra_args = 
megatron_args.parse_to_megatron() initialize_megatron(args_defaults=extra_args) mg_model = megatron_model_meta.model_provider() @@ -57,5 +56,5 @@ def test_save(): from swift.utils import set_default_ddp_config from swift.megatron.argument import MegatronArguments from swift.megatron.model import get_megatron_model_meta - from swift.megatron.utils import patch_megatron_tokenizer + from swift.megatron.utils import convert_hf_config test_save()
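
A minimal usage sketch of the state-dict bridge path this diff introduces, mirroring the calls now made in swift/megatron/convert.py and swift/megatron/trainers/base.py. The names `model_type`, `hf_model` and `mg_model` are placeholders, and Megatron is assumed to already be initialized (e.g. via initialize_megatron), since `GPTBridge.__init__` reads `get_args()`:

# Sketch only; `model_type`, `hf_model` and `mg_model` are assumed to exist.
from swift.megatron.model import get_megatron_model_meta

megatron_model_meta = get_megatron_model_meta(model_type)
bridge = megatron_model_meta.bridge_cls()  # GPTBridge, or Qwen3NextBridge for qwen3_next
# HF -> mcore: remap HF state_dict keys, then load with a strict key check (ignoring *._extra_state)
bridge.load_state_dict(mg_model, bridge.convert_hf2mcore(hf_model.state_dict()))
# mcore -> HF: the reverse direction used when exporting back to transformers format
bridge.load_state_dict(hf_model, bridge.convert_mcore2hf(mg_model.state_dict()))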