From 6e732f3995de771e2a1b18fec088b9ddaed4d9bc Mon Sep 17 00:00:00 2001 From: jason9693 Date: Mon, 20 Oct 2025 19:57:25 +0900 Subject: [PATCH 1/9] Enable for exporting unmerged HF Lora Adapter --- swift/megatron/model/gpt/mcore2hf.py | 7 + swift/megatron/model/gpt/mcore2hf_lora.py | 306 ++++++++++++++++++++++ swift/megatron/model/register.py | 9 + swift/megatron/utils/convert.py | 23 +- 4 files changed, 341 insertions(+), 4 deletions(-) create mode 100644 swift/megatron/model/gpt/mcore2hf_lora.py diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index eac8023801..b76e291e61 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. + from typing import Optional from megatron.training import get_args @@ -125,3 +126,9 @@ def convert_mcore2hf(hf_model, mg_model): hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight) for layer_idx in range(args.num_layers): set_layer_state(args, mg_model, hf_model.model, layer_idx) + + +def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups: int) -> None: + """Megatron Core LoRA 어댑터를 HuggingFace PEFT 형식으로 변환합니다.""" + from .mcore2hf_lora import convert_mcore_lora_to_hf_peft as _convert_mcore_lora_to_hf_peft + _convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir, num_groups) diff --git a/swift/megatron/model/gpt/mcore2hf_lora.py b/swift/megatron/model/gpt/mcore2hf_lora.py new file mode 100644 index 0000000000..fd65e8bdf8 --- /dev/null +++ b/swift/megatron/model/gpt/mcore2hf_lora.py @@ -0,0 +1,306 @@ +# Copyright (c) Kakao Corp. (AI Alignment Team). +# Contact: kevin.us@kakaocorp.com + +import json +import os +from collections import OrderedDict + +from safetensors.torch import save_file +from swift.utils import get_logger + +logger = get_logger() + + +def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups: int) -> None: + """ + Megatron Core LoRA 어댑터를 HuggingFace PEFT 형식으로 변환합니다. 
+ + Args: + peft_model: Megatron Core PEFTModel + mg_model: loaded Megatron Core Model (for shape) + dst_dir: Dir path to saving HuggingFace PEFT + num_groups: number of Attention group + """ + os.makedirs(dst_dir, exist_ok=True) + dst_model = os.path.join(dst_dir, "adapter_model.safetensors") + dst_cfg = os.path.join(dst_dir, "adapter_config.json") + + logger.info(f"Converting Megatron Core LoRA to HF PEFT format at {dst_dir}") + + # Megatron Core 모델에서 shape 추론 + logger.info("Extracting shape information from Megatron Core model...") + mg_language_model = mg_model.language_model if hasattr(mg_model, 'language_model') else mg_model + + # Megatron Core의 attention 모듈에서 shape 추론 + # 첫 번째 레이어의 attention 모듈 사용 + first_layer = mg_language_model.layers[0] + if hasattr(first_layer, 'self_attention'): + attn_module = first_layer.self_attention + else: + attn_module = first_layer.attention + + # Megatron Core의 attention shape 추론 + if hasattr(attn_module, 'linear_qkv'): + # fused qkv의 경우 + qkv_weight = attn_module.linear_qkv.weight + out_features, in_features = qkv_weight.shape + q_dim = out_features // (num_groups * 3) # q, k, v 각각 + kv_dim = q_dim + else: + # 분리된 q, k, v의 경우 + q_weight = attn_module.linear_q.weight + k_weight = attn_module.linear_k.weight + v_weight = attn_module.linear_v.weight + q_out, in_features = q_weight.shape + k_out, _ = k_weight.shape + v_out, _ = v_weight.shape + + q_dim = q_out // num_groups + kv_dim = k_out // num_groups + assert v_out // num_groups == kv_dim, "k/v group out dim mismatch" + + logger.info(f"Shape inference: num_groups={num_groups}, q_dim={q_dim}, kv_dim={kv_dim}, in_features={in_features}") + + # peft_model의 state_dict에서 모듈 단위 버킷화 + logger.info("Extracting LoRA weights from loaded PEFTModel...") + bucket = {} # prefix -> {local_name: tensor} + state_dict = peft_model.state_dict() + + for fullkey, tensor in state_dict.items(): + # adapter 관련 키만 처리 + if "lora_A" not in fullkey and "lora_B" not in fullkey: + continue + parts = fullkey.split(".") + local = ".".join(parts[-2:]) # e.g., lora_A.weight + prefix = ".".join(parts[:-2]) # e.g., ...linear_qkv + bucket.setdefault(prefix, {})[local] = tensor.cpu() + + dst_tensors = OrderedDict() + + def push(dst, key, tensor): + """텐서를 저장용으로 독립 복사 + 연속 메모리 보장""" + t = tensor.detach().clone().contiguous() + if key in dst: + raise ValueError(f"Duplicate key: {key}") + if "weight" not in key: + logger.debug(f"Skipping non-weight key: {key}") + return + key = remap_key_for_peft(key) + dst[key] = t + + def remap_key_for_peft(key: str) -> str: + """키를 HuggingFace PEFT 형식으로 변환""" + # 1) decoder → model + key = key.replace(".decoder.layers.", ".model.layers.") + # 2) self_attention → self_attn + key = key.replace(".self_attention.", ".self_attn.") + # 3) check prefix + if key.startswith("model.layers."): + key = "base_model.model." 
+ key + return key + + def convert_linear_proj(prefix, tensors): + """mcore: ...self_attention.linear_proj -> HF: ...self_attn.o_proj""" + new_prefix = prefix.replace(".self_attention.linear_proj", ".self_attn.o_proj") + for local, T in tensors.items(): + push(dst_tensors, f"{new_prefix}.{local}", T) + + def convert_linear_qkv(prefix, tensors): + """ + Megatron Core fused qkv LoRA를 HF q_proj, k_proj, v_proj로 분리 + + mcore: + A: [r, in_features] (공유) + B: [num_groups*(q_dim+kv_dim+kv_dim), r] + -> HF: + q_proj: A=[r,in], B=[num_groups*q_dim, r] + k_proj: A=[r,in], B=[num_groups*kv_dim, r] + v_proj: A=[r,in], B=[num_groups*kv_dim, r] + """ + A = tensors.get("lora_A.weight", None) + B = tensors.get("lora_B.weight", None) + if A is None or B is None: + # 핵심 가중치 없으면 원키로 pass + for local, T in tensors.items(): + push(dst_tensors, f"{prefix}.{local}", T) + return + + r, in_A = A.shape + out_B, rB = B.shape + assert rB == r, f"LoRA rank mismatch: A={r}, B={rB}" + assert in_A == in_features, f"in_features mismatch: A={in_A}, base={in_features}" + + expected_out = num_groups * (q_dim + kv_dim + kv_dim) + assert out_B == expected_out, f"Fused B out({out_B}) != expected({expected_out})" + + # [num_groups, (q_dim+kv_dim+kv_dim), r]로 reshape한 뒤 슬라이스 + Bg = B.reshape(num_groups, q_dim + kv_dim + kv_dim, r) + Bq = Bg[:, :q_dim, :].reshape(num_groups * q_dim, r) + Bk = Bg[:, q_dim:q_dim+kv_dim, :].reshape(num_groups * kv_dim, r) + Bv = Bg[:, q_dim+kv_dim:, :].reshape(num_groups * kv_dim, r) + + misc = {k: v for k, v in tensors.items() if k not in ("lora_A.weight", "lora_B.weight")} + + # q_proj + q_prefix = prefix.replace(".self_attention.linear_qkv", ".self_attn.q_proj") + push(dst_tensors, f"{q_prefix}.lora_A.weight", A) + push(dst_tensors, f"{q_prefix}.lora_B.weight", Bq) + + # k_proj + k_prefix = prefix.replace(".self_attention.linear_qkv", ".self_attn.k_proj") + push(dst_tensors, f"{k_prefix}.lora_A.weight", A) + push(dst_tensors, f"{k_prefix}.lora_B.weight", Bk) + for k, v in misc.items(): + push(dst_tensors, f"{k_prefix}.{k}", v) + + # v_proj + v_prefix = prefix.replace(".self_attention.linear_qkv", ".self_attn.v_proj") + push(dst_tensors, f"{v_prefix}.lora_A.weight", A) + push(dst_tensors, f"{v_prefix}.lora_B.weight", Bv) + for k, v in misc.items(): + push(dst_tensors, f"{v_prefix}.{k}", v) + + def convert_mla_attention(prefix, tensors): + """ + Multi-Latent Attention (MLA) LoRA 변환 + + mcore -> HF: + linear_q_down_proj -> q_a_proj + linear_q_up_proj -> q_b_proj + linear_kv_down_proj -> kv_a_proj_with_mqa + linear_kv_up_proj -> kv_b_proj + """ + # q_proj (down -> a, up -> b) + if ".linear_q_down_proj" in prefix: + new_prefix = prefix.replace(".linear_q_down_proj", ".q_a_proj") + for local, T in tensors.items(): + push(dst_tensors, f"{new_prefix}.{local}", T) + elif ".linear_q_up_proj" in prefix: + new_prefix = prefix.replace(".linear_q_up_proj", ".q_b_proj") + for local, T in tensors.items(): + push(dst_tensors, f"{new_prefix}.{local}", T) + elif ".linear_kv_down_proj" in prefix: + new_prefix = prefix.replace(".linear_kv_down_proj", ".kv_a_proj_with_mqa") + for local, T in tensors.items(): + push(dst_tensors, f"{new_prefix}.{local}", T) + elif ".linear_kv_up_proj" in prefix: + new_prefix = prefix.replace(".linear_kv_up_proj", ".kv_b_proj") + for local, T in tensors.items(): + push(dst_tensors, f"{new_prefix}.{local}", T) + + def convert_mlp_linear_fc1(prefix, tensors): + """ + MLP linear_fc1 LoRA를 HF gate_proj, up_proj로 분리 + + mcore: linear_fc1 [gate_up_dim, in_features] + -> HF: gate_proj [gate_dim, 
in_features], up_proj [up_dim, in_features] + """ + A = tensors.get("lora_A.weight", None) + B = tensors.get("lora_B.weight", None) + if A is None or B is None: + for local, T in tensors.items(): + push(dst_tensors, f"{prefix}.{local}", T) + return + + # gate_up_dim을 gate_dim과 up_dim으로 분리 (보통 1:1 비율) + gate_up_dim = B.shape[0] + gate_dim = gate_up_dim // 2 + up_dim = gate_up_dim - gate_dim + + # B를 gate와 up으로 분리 + B_gate = B[:gate_dim, :] + B_up = B[gate_dim:, :] + + misc = {k: v for k, v in tensors.items() if k not in ("lora_A.weight", "lora_B.weight")} + + # gate_proj + gate_prefix = prefix.replace(".mlp.linear_fc1", ".mlp.gate_proj") + push(dst_tensors, f"{gate_prefix}.lora_A.weight", A) + push(dst_tensors, f"{gate_prefix}.lora_B.weight", B_gate) + for k, v in misc.items(): + push(dst_tensors, f"{gate_prefix}.{k}", v) + + # up_proj + up_prefix = prefix.replace(".mlp.linear_fc1", ".mlp.up_proj") + push(dst_tensors, f"{up_prefix}.lora_A.weight", A) + push(dst_tensors, f"{up_prefix}.lora_B.weight", B_up) + for k, v in misc.items(): + push(dst_tensors, f"{up_prefix}.{k}", v) + + def convert_mlp_linear_fc2(prefix, tensors): + """MLP linear_fc2 LoRA를 HF down_proj로 변환""" + new_prefix = prefix.replace(".mlp.linear_fc2", ".mlp.down_proj") + for local, T in tensors.items(): + push(dst_tensors, f"{new_prefix}.{local}", T) + + def convert_moe_experts(prefix, tensors): + """MoE experts LoRA 변환""" + # experts[expert_idx].linear_fc1 -> experts[expert_idx].gate_proj, up_proj + if ".linear_fc1" in prefix: + convert_mlp_linear_fc1(prefix, tensors) + # experts[expert_idx].linear_fc2 -> experts[expert_idx].down_proj + elif ".linear_fc2" in prefix: + convert_mlp_linear_fc2(prefix, tensors) + + # 모듈별 변환 실행 + for prefix, tensors in bucket.items(): + # Attention 변환 + if ".self_attention.linear_proj" in prefix: + convert_linear_proj(prefix, tensors) + elif ".self_attention.linear_qkv" in prefix: + convert_linear_qkv(prefix, tensors) + # Multi-Latent Attention 변환 + elif any(x in prefix for x in [".linear_q_down_proj", ".linear_q_up_proj", + ".linear_kv_down_proj", ".linear_kv_up_proj"]): + convert_mla_attention(prefix, tensors) + # MLP 변환 + elif ".mlp.linear_fc1" in prefix: + convert_mlp_linear_fc1(prefix, tensors) + elif ".mlp.linear_fc2" in prefix: + convert_mlp_linear_fc2(prefix, tensors) + # MoE experts 변환 (router는 제외) + elif ".experts" in prefix and (".linear_fc1" in prefix or ".linear_fc2" in prefix): + convert_moe_experts(prefix, tensors) + else: + # 알 수 없는 모듈은 그대로 복사 + logger.warning(f"Unknown module pattern: {prefix}") + for local, T in tensors.items(): + push(dst_tensors, f"{prefix}.{local}", T) + + # 변환된 텐서 저장 + save_file(dst_tensors, dst_model, metadata={"format": "pt"}) + logger.info(f"Saved converted LoRA tensors to {dst_model}") + + # adapter_config.json 갱신 + logger.info("Converting adapter config...") + cfg = peft_model.config if isinstance(peft_model.config, dict) else peft_model.config.__dict__.copy() + + tm = cfg.get("target_modules", None) + if tm is not None: + if isinstance(tm, str): + tm = [tm] + new_tm = [] + for t in tm: + if t == "linear_proj": + new_tm.append("o_proj") + elif t in ("linear_qkv", "query_key_value"): + new_tm.extend(["q_proj", "k_proj", "v_proj"]) + elif t == "linear_fc1": + new_tm.extend(["gate_proj", "up_proj"]) + elif t == "linear_fc2": + new_tm.append("down_proj") + elif t == "linear_q_down_proj": + new_tm.append("q_a_proj") + elif t == "linear_q_up_proj": + new_tm.append("q_b_proj") + elif t == "linear_kv_down_proj": + new_tm.append("kv_a_proj_with_mqa") + elif t == 
"linear_kv_up_proj": + new_tm.append("kv_b_proj") + else: + new_tm.append(t) + cfg["target_modules"] = sorted(set(new_tm)) + + with open(dst_cfg, "w", encoding="utf-8") as f: + json.dump(cfg, f, ensure_ascii=False, indent=2) + logger.info(f"Saved converted adapter config to {dst_cfg}") diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 93f892c2e8..1a002843a0 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -28,6 +28,15 @@ class MegatronModelMeta: extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None + def convert_mcore_lora_to_hf_peft(self, peft_model, mg_model, dst_dir: str, num_groups: int) -> None: + """Megatron Core LoRA 어댑터를 HuggingFace PEFT 형식으로 변환합니다.""" + # only for gpt model type + if self.megatron_model_type != 'gpt': + raise ValueError(f"convert_mcore_lora_to_hf_peft is only supported for gpt model type, but got {self.megatron_model_type}") + + from .gpt.mcore2hf import convert_mcore_lora_to_hf_peft + convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir, num_groups) + def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): megatron_model_type = megatron_model_meta.megatron_model_type diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index aa3202f580..f7cd125d41 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -2,7 +2,7 @@ import math from contextlib import contextmanager -from dataclasses import fields +from dataclasses import fields, asdict from typing import Any, Dict import torch @@ -307,10 +307,25 @@ def convert_mcore2hf(args: ExportArguments) -> None: peft_model = prepare_mcore_model(mg_model) with adapter_state_dict_context(): load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) - logger.info('Merge LoRA...') - mg_model = peft_model.merge_and_unload() + if args.to_hf and not args.merge_lora: + logger.info(f"Saving LoRA adapter to `{args.output_dir}` ...") + assert args.multi_latent_attention is False, "Multi-latent attention is not supported for LoRA conversion." 
+ + peft_model.config = asdict(peft_model.config) # for PEFT <= 0.17.1 + # Megatron Core LoRA를 HuggingFace PEFT 형식으로 변환 + megatron_model_meta.convert_mcore_lora_to_hf_peft( + peft_model=peft_model, + mg_model=mg_model, + dst_dir=args.output_dir, + num_groups=args.num_query_groups if args.group_query_attention else args.num_attention_heads + ) + logger.info("LoRA adapter saved successfully.") + else: + logger.info('Merge LoRA...') + mg_model = peft_model.merge_and_unload() logger.info('Megatron model created successfully.') - if args.to_hf: + + if args.to_hf and not args.merge_lora: hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) megatron_model_meta.convert_mcore2hf(hf_model, mg_model) if args.test_convert_precision: From e03e889cd51903374674dc1e26379c147fe948d8 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Mon, 20 Oct 2025 20:13:21 +0900 Subject: [PATCH 2/9] modified control logic --- swift/megatron/utils/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index f7cd125d41..2ae880595b 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -325,7 +325,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: mg_model = peft_model.merge_and_unload() logger.info('Megatron model created successfully.') - if args.to_hf and not args.merge_lora: + if args.to_hf and args.merge_lora: hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) megatron_model_meta.convert_mcore2hf(hf_model, mg_model) if args.test_convert_precision: From e84dcc80639068a4eabd676b032557fcecead580 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Mon, 20 Oct 2025 21:29:17 +0900 Subject: [PATCH 3/9] modified minor bugs --- swift/megatron/model/gpt/mcore2hf_lora.py | 4 ++-- swift/megatron/utils/convert.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/swift/megatron/model/gpt/mcore2hf_lora.py b/swift/megatron/model/gpt/mcore2hf_lora.py index fd65e8bdf8..71c079dfa0 100644 --- a/swift/megatron/model/gpt/mcore2hf_lora.py +++ b/swift/megatron/model/gpt/mcore2hf_lora.py @@ -33,7 +33,7 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups # Megatron Core의 attention 모듈에서 shape 추론 # 첫 번째 레이어의 attention 모듈 사용 - first_layer = mg_language_model.layers[0] + first_layer = mg_language_model.decoder.layers[0] if hasattr(first_layer, 'self_attention'): attn_module = first_layer.self_attention else: @@ -302,5 +302,5 @@ def convert_moe_experts(prefix, tensors): cfg["target_modules"] = sorted(set(new_tm)) with open(dst_cfg, "w", encoding="utf-8") as f: - json.dump(cfg, f, ensure_ascii=False, indent=2) + json.dump(cfg, f, ensure_ascii=False, indent=2, default=str) logger.info(f"Saved converted adapter config to {dst_cfg}") diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index 2ae880595b..d73614ae77 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -309,7 +309,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) if args.to_hf and not args.merge_lora: logger.info(f"Saving LoRA adapter to `{args.output_dir}` ...") - assert args.multi_latent_attention is False, "Multi-latent attention is not supported for LoRA conversion." + assert megatron_args.multi_latent_attention is False, "Multi-latent attention is not supported for LoRA conversion." 
peft_model.config = asdict(peft_model.config) # for PEFT <= 0.17.1 # Megatron Core LoRA를 HuggingFace PEFT 형식으로 변환 @@ -317,9 +317,10 @@ def convert_mcore2hf(args: ExportArguments) -> None: peft_model=peft_model, mg_model=mg_model, dst_dir=args.output_dir, - num_groups=args.num_query_groups if args.group_query_attention else args.num_attention_heads + num_groups=megatron_args.num_query_groups if megatron_args.group_query_attention else megatron_args.num_attention_heads ) logger.info("LoRA adapter saved successfully.") + return else: logger.info('Merge LoRA...') mg_model = peft_model.merge_and_unload() From 8750db2d97f5ff30b7e2958af2c81da659d90a52 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Mon, 20 Oct 2025 23:50:08 +0900 Subject: [PATCH 4/9] modified key mapping --- swift/megatron/model/gpt/mcore2hf_lora.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/swift/megatron/model/gpt/mcore2hf_lora.py b/swift/megatron/model/gpt/mcore2hf_lora.py index 71c079dfa0..5975962c55 100644 --- a/swift/megatron/model/gpt/mcore2hf_lora.py +++ b/swift/megatron/model/gpt/mcore2hf_lora.py @@ -4,6 +4,7 @@ import json import os from collections import OrderedDict +from dataclasses import asdict from safetensors.torch import save_file from swift.utils import get_logger @@ -273,7 +274,7 @@ def convert_moe_experts(prefix, tensors): # adapter_config.json 갱신 logger.info("Converting adapter config...") - cfg = peft_model.config if isinstance(peft_model.config, dict) else peft_model.config.__dict__.copy() + cfg = peft_model.peft_config['default'] if isinstance(peft_model.peft_config['default'], dict) else asdict(peft_model.peft_config['default']) tm = cfg.get("target_modules", None) if tm is not None: @@ -303,4 +304,6 @@ def convert_moe_experts(prefix, tensors): with open(dst_cfg, "w", encoding="utf-8") as f: json.dump(cfg, f, ensure_ascii=False, indent=2, default=str) + + logger.info(f"cfg: {cfg}") logger.info(f"Saved converted adapter config to {dst_cfg}") From bbe153aee5064c98f20e7a877962b719b014f8b9 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Tue, 21 Oct 2025 15:30:14 +0900 Subject: [PATCH 5/9] change comments to english --- swift/megatron/model/gpt/mcore2hf.py | 2 +- swift/megatron/model/gpt/mcore2hf_lora.py | 72 +++++++++++++---------- swift/megatron/model/register.py | 2 +- swift/megatron/utils/convert.py | 2 +- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index b76e291e61..117151c7a8 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -129,6 +129,6 @@ def convert_mcore2hf(hf_model, mg_model): def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups: int) -> None: - """Megatron Core LoRA 어댑터를 HuggingFace PEFT 형식으로 변환합니다.""" + """Convert Megatron Core LoRA adapter to HuggingFace PEFT format.""" from .mcore2hf_lora import convert_mcore_lora_to_hf_peft as _convert_mcore_lora_to_hf_peft _convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir, num_groups) diff --git a/swift/megatron/model/gpt/mcore2hf_lora.py b/swift/megatron/model/gpt/mcore2hf_lora.py index 5975962c55..b3a19d6bdd 100644 --- a/swift/megatron/model/gpt/mcore2hf_lora.py +++ b/swift/megatron/model/gpt/mcore2hf_lora.py @@ -14,7 +14,7 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups: int) -> None: """ - Megatron Core LoRA 어댑터를 HuggingFace PEFT 형식으로 변환합니다. + Convert Megatron Core LoRA adapter to HuggingFace PEFT format. 
Args: peft_model: Megatron Core PEFTModel @@ -28,27 +28,26 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups logger.info(f"Converting Megatron Core LoRA to HF PEFT format at {dst_dir}") - # Megatron Core 모델에서 shape 추론 + # Extract shape information from Megatron Core model logger.info("Extracting shape information from Megatron Core model...") mg_language_model = mg_model.language_model if hasattr(mg_model, 'language_model') else mg_model - - # Megatron Core의 attention 모듈에서 shape 추론 - # 첫 번째 레이어의 attention 모듈 사용 + # Extract shape information from Megatron Core attention module + # Use attention module from the first layer first_layer = mg_language_model.decoder.layers[0] if hasattr(first_layer, 'self_attention'): attn_module = first_layer.self_attention else: attn_module = first_layer.attention - # Megatron Core의 attention shape 추론 + # Extract attention shape from Megatron Core if hasattr(attn_module, 'linear_qkv'): - # fused qkv의 경우 + # For fused qkv case qkv_weight = attn_module.linear_qkv.weight out_features, in_features = qkv_weight.shape q_dim = out_features // (num_groups * 3) # q, k, v 각각 kv_dim = q_dim else: - # 분리된 q, k, v의 경우 + # For separated q, k, v case q_weight = attn_module.linear_q.weight k_weight = attn_module.linear_k.weight v_weight = attn_module.linear_v.weight @@ -62,24 +61,33 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups logger.info(f"Shape inference: num_groups={num_groups}, q_dim={q_dim}, kv_dim={kv_dim}, in_features={in_features}") - # peft_model의 state_dict에서 모듈 단위 버킷화 + # Bucketize modules from peft_model state_dict logger.info("Extracting LoRA weights from loaded PEFTModel...") bucket = {} # prefix -> {local_name: tensor} state_dict = peft_model.state_dict() for fullkey, tensor in state_dict.items(): - # adapter 관련 키만 처리 + # Process only adapter-related keys if "lora_A" not in fullkey and "lora_B" not in fullkey: continue parts = fullkey.split(".") - local = ".".join(parts[-2:]) # e.g., lora_A.weight - prefix = ".".join(parts[:-2]) # e.g., ...linear_qkv + + # Parse key considering .default.weight format + if len(parts) >= 2 and parts[-2] == "default": + # e.g., lora_A.default.weight -> lora_A.weight + local = f"{parts[-3]}.{parts[-1]}" # default.weight + prefix = ".".join(parts[:-3]) # e.g., ...linear_qkv + else: + # Original logic: e.g., lora_A.weight + local = ".".join(parts[-2:]) # e.g., lora_A.weight + prefix = ".".join(parts[:-2]) # e.g., ...linear_qkv + bucket.setdefault(prefix, {})[local] = tensor.cpu() dst_tensors = OrderedDict() def push(dst, key, tensor): - """텐서를 저장용으로 독립 복사 + 연속 메모리 보장""" + """Create independent copy of tensor for saving + ensure contiguous memory""" t = tensor.detach().clone().contiguous() if key in dst: raise ValueError(f"Duplicate key: {key}") @@ -90,7 +98,7 @@ def push(dst, key, tensor): dst[key] = t def remap_key_for_peft(key: str) -> str: - """키를 HuggingFace PEFT 형식으로 변환""" + """Convert key to HuggingFace PEFT format""" # 1) decoder → model key = key.replace(".decoder.layers.", ".model.layers.") # 2) self_attention → self_attn @@ -108,10 +116,10 @@ def convert_linear_proj(prefix, tensors): def convert_linear_qkv(prefix, tensors): """ - Megatron Core fused qkv LoRA를 HF q_proj, k_proj, v_proj로 분리 + Split Megatron Core fused qkv LoRA into HF q_proj, k_proj, v_proj mcore: - A: [r, in_features] (공유) + A: [r, in_features] (shared) B: [num_groups*(q_dim+kv_dim+kv_dim), r] -> HF: q_proj: A=[r,in], B=[num_groups*q_dim, r] @@ -121,7 +129,7 @@ def 
convert_linear_qkv(prefix, tensors): A = tensors.get("lora_A.weight", None) B = tensors.get("lora_B.weight", None) if A is None or B is None: - # 핵심 가중치 없으면 원키로 pass + # If core weights are missing, pass through with original key for local, T in tensors.items(): push(dst_tensors, f"{prefix}.{local}", T) return @@ -134,7 +142,7 @@ def convert_linear_qkv(prefix, tensors): expected_out = num_groups * (q_dim + kv_dim + kv_dim) assert out_B == expected_out, f"Fused B out({out_B}) != expected({expected_out})" - # [num_groups, (q_dim+kv_dim+kv_dim), r]로 reshape한 뒤 슬라이스 + # Reshape to [num_groups, (q_dim+kv_dim+kv_dim), r] then slice Bg = B.reshape(num_groups, q_dim + kv_dim + kv_dim, r) Bq = Bg[:, :q_dim, :].reshape(num_groups * q_dim, r) Bk = Bg[:, q_dim:q_dim+kv_dim, :].reshape(num_groups * kv_dim, r) @@ -163,7 +171,7 @@ def convert_linear_qkv(prefix, tensors): def convert_mla_attention(prefix, tensors): """ - Multi-Latent Attention (MLA) LoRA 변환 + Multi-Latent Attention (MLA) LoRA conversion mcore -> HF: linear_q_down_proj -> q_a_proj @@ -191,7 +199,7 @@ def convert_mla_attention(prefix, tensors): def convert_mlp_linear_fc1(prefix, tensors): """ - MLP linear_fc1 LoRA를 HF gate_proj, up_proj로 분리 + Split MLP linear_fc1 LoRA into HF gate_proj, up_proj mcore: linear_fc1 [gate_up_dim, in_features] -> HF: gate_proj [gate_dim, in_features], up_proj [up_dim, in_features] @@ -203,12 +211,12 @@ def convert_mlp_linear_fc1(prefix, tensors): push(dst_tensors, f"{prefix}.{local}", T) return - # gate_up_dim을 gate_dim과 up_dim으로 분리 (보통 1:1 비율) + # Split gate_up_dim into gate_dim and up_dim (usually 1:1 ratio) gate_up_dim = B.shape[0] gate_dim = gate_up_dim // 2 up_dim = gate_up_dim - gate_dim - # B를 gate와 up으로 분리 + # Split B into gate and up B_gate = B[:gate_dim, :] B_up = B[gate_dim:, :] @@ -229,13 +237,13 @@ def convert_mlp_linear_fc1(prefix, tensors): push(dst_tensors, f"{up_prefix}.{k}", v) def convert_mlp_linear_fc2(prefix, tensors): - """MLP linear_fc2 LoRA를 HF down_proj로 변환""" + """Convert MLP linear_fc2 LoRA to HF down_proj""" new_prefix = prefix.replace(".mlp.linear_fc2", ".mlp.down_proj") for local, T in tensors.items(): push(dst_tensors, f"{new_prefix}.{local}", T) def convert_moe_experts(prefix, tensors): - """MoE experts LoRA 변환""" + """MoE experts LoRA conversion""" # experts[expert_idx].linear_fc1 -> experts[expert_idx].gate_proj, up_proj if ".linear_fc1" in prefix: convert_mlp_linear_fc1(prefix, tensors) @@ -243,36 +251,36 @@ def convert_moe_experts(prefix, tensors): elif ".linear_fc2" in prefix: convert_mlp_linear_fc2(prefix, tensors) - # 모듈별 변환 실행 + # Execute conversion by module for prefix, tensors in bucket.items(): - # Attention 변환 + # Attention conversion if ".self_attention.linear_proj" in prefix: convert_linear_proj(prefix, tensors) elif ".self_attention.linear_qkv" in prefix: convert_linear_qkv(prefix, tensors) - # Multi-Latent Attention 변환 + # Multi-Latent Attention conversion elif any(x in prefix for x in [".linear_q_down_proj", ".linear_q_up_proj", ".linear_kv_down_proj", ".linear_kv_up_proj"]): convert_mla_attention(prefix, tensors) - # MLP 변환 + # MLP conversion elif ".mlp.linear_fc1" in prefix: convert_mlp_linear_fc1(prefix, tensors) elif ".mlp.linear_fc2" in prefix: convert_mlp_linear_fc2(prefix, tensors) - # MoE experts 변환 (router는 제외) + # MoE experts conversion (excluding router) elif ".experts" in prefix and (".linear_fc1" in prefix or ".linear_fc2" in prefix): convert_moe_experts(prefix, tensors) else: - # 알 수 없는 모듈은 그대로 복사 + # Copy unknown modules as-is 
logger.warning(f"Unknown module pattern: {prefix}") for local, T in tensors.items(): push(dst_tensors, f"{prefix}.{local}", T) - # 변환된 텐서 저장 + # Save converted tensors save_file(dst_tensors, dst_model, metadata={"format": "pt"}) logger.info(f"Saved converted LoRA tensors to {dst_model}") - # adapter_config.json 갱신 + # Update adapter_config.json logger.info("Converting adapter config...") cfg = peft_model.peft_config['default'] if isinstance(peft_model.peft_config['default'], dict) else asdict(peft_model.peft_config['default']) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 1a002843a0..c6c9e687d3 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -29,7 +29,7 @@ class MegatronModelMeta: extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None def convert_mcore_lora_to_hf_peft(self, peft_model, mg_model, dst_dir: str, num_groups: int) -> None: - """Megatron Core LoRA 어댑터를 HuggingFace PEFT 형식으로 변환합니다.""" + """Convert Megatron Core LoRA adapter to HuggingFace PEFT format.""" # only for gpt model type if self.megatron_model_type != 'gpt': raise ValueError(f"convert_mcore_lora_to_hf_peft is only supported for gpt model type, but got {self.megatron_model_type}") diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index d73614ae77..d1eee10151 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -312,7 +312,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: assert megatron_args.multi_latent_attention is False, "Multi-latent attention is not supported for LoRA conversion." peft_model.config = asdict(peft_model.config) # for PEFT <= 0.17.1 - # Megatron Core LoRA를 HuggingFace PEFT 형식으로 변환 + # Convert Megatron Core LoRA to HuggingFace PEFT format megatron_model_meta.convert_mcore_lora_to_hf_peft( peft_model=peft_model, mg_model=mg_model, From 3d124b892a200198fa1657e47304eb0a3d1b5b86 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Wed, 22 Oct 2025 00:43:43 +0900 Subject: [PATCH 6/9] add test code and convert logic changed (\w use hf model) --- swift/megatron/model/gpt/mcore2hf.py | 4 +- swift/megatron/model/gpt/mcore2hf_lora.py | 66 ++---- swift/megatron/model/register.py | 4 +- swift/megatron/utils/convert.py | 5 +- tests/megatron/test_lora_export.py | 260 ++++++++++++++++++++++ 5 files changed, 292 insertions(+), 47 deletions(-) create mode 100644 tests/megatron/test_lora_export.py diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index 117151c7a8..95a68bf17d 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -128,7 +128,7 @@ def convert_mcore2hf(hf_model, mg_model): set_layer_state(args, mg_model, hf_model.model, layer_idx) -def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups: int) -> None: +def convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir: str, num_groups: int) -> None: """Convert Megatron Core LoRA adapter to HuggingFace PEFT format.""" from .mcore2hf_lora import convert_mcore_lora_to_hf_peft as _convert_mcore_lora_to_hf_peft - _convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir, num_groups) + _convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir, num_groups) diff --git a/swift/megatron/model/gpt/mcore2hf_lora.py b/swift/megatron/model/gpt/mcore2hf_lora.py index b3a19d6bdd..4a2f6cc941 100644 --- a/swift/megatron/model/gpt/mcore2hf_lora.py +++ 
b/swift/megatron/model/gpt/mcore2hf_lora.py @@ -12,15 +12,16 @@ logger = get_logger() -def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups: int) -> None: +def convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir: str, num_query_groups: int) -> None: """ Convert Megatron Core LoRA adapter to HuggingFace PEFT format. Args: peft_model: Megatron Core PEFTModel mg_model: loaded Megatron Core Model (for shape) + hf_model: HuggingFace model (required for shape extraction) dst_dir: Dir path to saving HuggingFace PEFT - num_groups: number of Attention group + num_query_groups: number of Attention group """ os.makedirs(dst_dir, exist_ok=True) dst_model = os.path.join(dst_dir, "adapter_model.safetensors") @@ -28,38 +29,19 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir: str, num_groups logger.info(f"Converting Megatron Core LoRA to HF PEFT format at {dst_dir}") - # Extract shape information from Megatron Core model - logger.info("Extracting shape information from Megatron Core model...") - mg_language_model = mg_model.language_model if hasattr(mg_model, 'language_model') else mg_model - # Extract shape information from Megatron Core attention module - # Use attention module from the first layer - first_layer = mg_language_model.decoder.layers[0] - if hasattr(first_layer, 'self_attention'): - attn_module = first_layer.self_attention - else: - attn_module = first_layer.attention + # Extract shape information from HuggingFace model + logger.info("Extracting shape information from HuggingFace model...") + attn0 = hf_model.model.layers[0].self_attn - # Extract attention shape from Megatron Core - if hasattr(attn_module, 'linear_qkv'): - # For fused qkv case - qkv_weight = attn_module.linear_qkv.weight - out_features, in_features = qkv_weight.shape - q_dim = out_features // (num_groups * 3) # q, k, v 각각 - kv_dim = q_dim - else: - # For separated q, k, v case - q_weight = attn_module.linear_q.weight - k_weight = attn_module.linear_k.weight - v_weight = attn_module.linear_v.weight - q_out, in_features = q_weight.shape - k_out, _ = k_weight.shape - v_out, _ = v_weight.shape - - q_dim = q_out // num_groups - kv_dim = k_out // num_groups - assert v_out // num_groups == kv_dim, "k/v group out dim mismatch" + q_out, in_features = attn0.q_proj.weight.shape # [out, in] + k_out, _ = attn0.k_proj.weight.shape + v_out, _ = attn0.v_proj.weight.shape + + q_dim = q_out // num_query_groups + kv_dim = k_out // num_query_groups + assert v_out // num_query_groups == kv_dim, "k/v group out dim mismatch" - logger.info(f"Shape inference: num_groups={num_groups}, q_dim={q_dim}, kv_dim={kv_dim}, in_features={in_features}") + logger.info(f"Shape extraction: num_query_groups={num_query_groups}, q_dim={q_dim}, kv_dim={kv_dim}, in_features={in_features}") # Bucketize modules from peft_model state_dict logger.info("Extracting LoRA weights from loaded PEFTModel...") @@ -120,11 +102,11 @@ def convert_linear_qkv(prefix, tensors): mcore: A: [r, in_features] (shared) - B: [num_groups*(q_dim+kv_dim+kv_dim), r] + B: [num_query_groups*(q_dim+kv_dim+kv_dim), r] -> HF: - q_proj: A=[r,in], B=[num_groups*q_dim, r] - k_proj: A=[r,in], B=[num_groups*kv_dim, r] - v_proj: A=[r,in], B=[num_groups*kv_dim, r] + q_proj: A=[r,in], B=[num_query_groups*q_dim, r] + k_proj: A=[r,in], B=[num_query_groups*kv_dim, r] + v_proj: A=[r,in], B=[num_query_groups*kv_dim, r] """ A = tensors.get("lora_A.weight", None) B = tensors.get("lora_B.weight", None) @@ -139,14 +121,14 @@ def 
convert_linear_qkv(prefix, tensors): assert rB == r, f"LoRA rank mismatch: A={r}, B={rB}" assert in_A == in_features, f"in_features mismatch: A={in_A}, base={in_features}" - expected_out = num_groups * (q_dim + kv_dim + kv_dim) + expected_out = num_query_groups * (q_dim + kv_dim + kv_dim) assert out_B == expected_out, f"Fused B out({out_B}) != expected({expected_out})" - # Reshape to [num_groups, (q_dim+kv_dim+kv_dim), r] then slice - Bg = B.reshape(num_groups, q_dim + kv_dim + kv_dim, r) - Bq = Bg[:, :q_dim, :].reshape(num_groups * q_dim, r) - Bk = Bg[:, q_dim:q_dim+kv_dim, :].reshape(num_groups * kv_dim, r) - Bv = Bg[:, q_dim+kv_dim:, :].reshape(num_groups * kv_dim, r) + # Reshape to [num_query_groups, (q_dim+kv_dim+kv_dim), r] then slice + Bg = B.reshape(num_query_groups, q_dim + kv_dim + kv_dim, r) + Bq = Bg[:, :q_dim, :].reshape(num_query_groups * q_dim, r) + Bk = Bg[:, q_dim:q_dim+kv_dim, :].reshape(num_query_groups * kv_dim, r) + Bv = Bg[:, q_dim+kv_dim:, :].reshape(num_query_groups * kv_dim, r) misc = {k: v for k, v in tensors.items() if k not in ("lora_A.weight", "lora_B.weight")} diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index c6c9e687d3..6c1ddb3f7b 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -28,14 +28,14 @@ class MegatronModelMeta: extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None - def convert_mcore_lora_to_hf_peft(self, peft_model, mg_model, dst_dir: str, num_groups: int) -> None: + def convert_mcore_lora_to_hf_peft(self, peft_model, mg_model, hf_model, dst_dir: str, num_groups: int) -> None: """Convert Megatron Core LoRA adapter to HuggingFace PEFT format.""" # only for gpt model type if self.megatron_model_type != 'gpt': raise ValueError(f"convert_mcore_lora_to_hf_peft is only supported for gpt model type, but got {self.megatron_model_type}") from .gpt.mcore2hf import convert_mcore_lora_to_hf_peft - convert_mcore_lora_to_hf_peft(peft_model, mg_model, dst_dir, num_groups) + convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir, num_groups) def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index d1eee10151..fe2b98abd1 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -303,6 +303,9 @@ def convert_mcore2hf(args: ExportArguments) -> None: if megatron_args.load is None: raise ValueError('Please specify `--mcore_model`.') load_checkpoint([mg_model], None, None, strict=True) + if args.to_hf: + hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) + if megatron_args.adapter_load is not None: peft_model = prepare_mcore_model(mg_model) with adapter_state_dict_context(): @@ -316,6 +319,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: megatron_model_meta.convert_mcore_lora_to_hf_peft( peft_model=peft_model, mg_model=mg_model, + hf_model=hf_model, dst_dir=args.output_dir, num_groups=megatron_args.num_query_groups if megatron_args.group_query_attention else megatron_args.num_attention_heads ) @@ -327,7 +331,6 @@ def convert_mcore2hf(args: ExportArguments) -> None: logger.info('Megatron model created successfully.') if args.to_hf and args.merge_lora: - hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) megatron_model_meta.convert_mcore2hf(hf_model, mg_model) if args.test_convert_precision: 
test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) diff --git a/tests/megatron/test_lora_export.py b/tests/megatron/test_lora_export.py new file mode 100644 index 0000000000..73ff443c94 --- /dev/null +++ b/tests/megatron/test_lora_export.py @@ -0,0 +1,260 @@ +import os +import math +import torch +import torch.nn as nn +from dataclasses import dataclass +from typing import Dict, Tuple, List, Optional + +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import PeftModel + +# ------------------------------ +# Configuration +# ------------------------------ +MERGED_MODEL = "/path/to/merged/model" +BASE_MODEL = "/path/to/base/model" +ADAPTER_DIR = "/path/to/adapter/directory" +TOKENIZER = BASE_MODEL + +# GPU assignment (0-based) +GPU_MERGED = 0 +GPU_PEFT = 1 + +# Tolerance +ATOL = 1e-5 +RTOL = 1e-4 + +DTYPE = torch.float16 +MAX_NEW_TOKENS = 128 +PAD_AS_EOS = True + +PROMPTS = [ + "User: Please explain the definition of CPU.\nAssistant: ", + "Translate the following sentence to English: 'The wind is so refreshing today.'", +] + + +# ----------------------------- +# Utilities +# ----------------------------- +def pin_to_single_gpu(model_name: str, gpu_index: int, dtype=torch.bfloat16): + device_map = {"": gpu_index} + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=dtype, + device_map=device_map + ).eval() + return model + +def attach_peft_on_gpu(base_model_name: str, adapter_dir: str, gpu_index: int, dtype=torch.bfloat16): + device_map = {"": gpu_index} + base = AutoModelForCausalLM.from_pretrained( + base_model_name, + torch_dtype=dtype, + device_map=device_map + ).eval() + model = PeftModel.from_pretrained(base, adapter_dir).eval() + return model + +def make_inputs(tokenizer: AutoTokenizer, seq_len: int = 32, batch_size: int = 2): + texts = [f"Verification sample #{i}." 
for i in range(batch_size)] + enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=seq_len) + return enc + +@dataclass +class DiffStat: + max_abs: float + mean_abs: float + cos_sim: float + +def tensor_diff(a: torch.Tensor, b: torch.Tensor) -> DiffStat: + a = a.detach().float().cpu().reshape(-1) + b = b.detach().float().cpu().reshape(-1) + max_abs = (a - b).abs().max().item() + mean_abs = (a - b).abs().mean().item() + cos = float("nan") if a.norm() == 0 or b.norm() == 0 else torch.nn.functional.cosine_similarity(a, b, dim=0).item() + return DiffStat(max_abs, mean_abs, cos) + +def report(stat: DiffStat, tag: str): + ok = stat.max_abs <= ATOL + RTOL * max(1.0, stat.mean_abs) + print(f"[{tag}] max|Δ|={stat.max_abs:.3e} mean|Δ|={stat.mean_abs:.3e} cos={stat.cos_sim:.6f} -> {'OK' if ok else 'MISMATCH'}") + return ok + +# ----------------------------- +# 1) End-to-end comparison +# ----------------------------- +@torch.inference_mode() +def compare_e2e(merged, peft_model, tokenizer): + batch_cpu = make_inputs(tokenizer) + batch_m = {k: v.to(f"cuda:{GPU_MERGED}") for k, v in batch_cpu.items()} + batch_p = {k: v.to(f"cuda:{GPU_PEFT}") for k, v in batch_cpu.items()} + + out_m = merged(**batch_m, output_hidden_states=True) + out_p = peft_model(**batch_p, output_hidden_states=True) + + ok1 = report(tensor_diff(out_m.logits, out_p.logits), "logits") + ok2 = report(tensor_diff(out_m.hidden_states[-1], out_p.hidden_states[-1]), "hidden_states[-1]") + return ok1 and ok2 + +# ----------------------------- +# 2) Module-wise effective weight comparison +# (W_eff = W0 + Σ(B@A)*scale) +# ----------------------------- +def find_linear_modules(model: nn.Module, suffixes=("q_proj","k_proj","v_proj","o_proj")) -> Dict[str, nn.Linear]: + out = {} + for name, mod in model.named_modules(): + if isinstance(mod, nn.Linear) and any(name.endswith(f".self_attn.{suf}") for suf in suffixes): + out[name] = mod + return out + +def peft_has_lora(mod: nn.Module) -> bool: + return hasattr(mod, "lora_A") and hasattr(mod, "lora_B") + +def peft_effective_weight(linear_mod: nn.Module) -> torch.Tensor: + W0 = linear_mod.weight.detach().float().cpu() + if not peft_has_lora(linear_mod): + return W0 + + delta = torch.zeros_like(W0) + fan_in_fan_out = getattr(linear_mod, "fan_in_fan_out", False) + + for name, A in linear_mod.lora_A.items(): + if name not in linear_mod.lora_B: + continue + B = linear_mod.lora_B[name] + + A_w = A.weight.detach().float().cpu() + B_w = B.weight.detach().float().cpu() + + BA = (A_w.t() @ B_w.t()).t() if fan_in_fan_out else (B_w @ A_w) + + # scaling + if hasattr(linear_mod, "scaling"): + scale = float(linear_mod.scaling[name]) + else: + r = A_w.shape[0] + alpha = getattr(linear_mod, "lora_alpha", r) + scale = float(alpha) / float(r) + + delta += BA * scale + + return W0 + delta + +def _resolve_in_peft(peft_model: nn.Module, merged_name: str) -> Optional[nn.Module]: + """ + Based on the merged module name, sequentially try possible prefixes in the PEFT wrapper. 
+ """ + candidates = [ + merged_name, + f"base_model.{merged_name}", + f"base_model.model.{merged_name}", + f"base_model.model.model.{merged_name}", + ] + peft_named = dict(peft_model.named_modules()) + for cand in candidates: + if cand in peft_named: + return peft_named[cand] + return None + +@torch.inference_mode() +def compare_weights(merged, peft_model): + ok_all = True + merged_lin = find_linear_modules(merged) + + for name, m_lin in merged_lin.items(): + p_lin = _resolve_in_peft(peft_model, name) + if p_lin is None: + print(f"[SKIP] Cannot resolve in PEFT: {name}") + ok_all = False + continue + + W_merged = m_lin.weight.detach().float().cpu() + W_peft_eff = peft_effective_weight(p_lin) + + ok = report(tensor_diff(W_merged, W_peft_eff), f"Weights::{name}") + ok_all = ok_all and ok + + return ok_all + +# ----------------------------- +# Generation comparison +# ----------------------------- +def load_models(): + tok = AutoTokenizer.from_pretrained(TOKENIZER, use_fast=True) + if PAD_AS_EOS and tok.pad_token_id is None: + tok.pad_token = tok.eos_token + tok.padding_side = "left" # Improved batch padding stability in causal LM + + merged = AutoModelForCausalLM.from_pretrained( + MERGED_MODEL, torch_dtype=DTYPE, device_map={"": GPU_MERGED} + ).eval() + + base = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, torch_dtype=DTYPE, device_map={"": GPU_PEFT} + ).eval() + + peft_model = PeftModel.from_pretrained(base, ADAPTER_DIR).eval() + model = peft_model.merge_and_unload().eval() + + return tok, merged, model + +@torch.inference_mode() +def run_generate(model, tok, prompts, device, **gen_kwargs): + enc = tok(prompts, return_tensors="pt", padding=True) + enc = {k: v.to(f"cuda:{device}") for k, v in enc.items()} + out = model.generate( + **enc, + **gen_kwargs, + return_dict_in_generate=True, + output_scores=True + ) + texts = tok.batch_decode(out.sequences, skip_special_tokens=True) + return out, texts + +def compare_texts(a_list, b_list): + ok = True + for i, (a, b) in enumerate(zip(a_list, b_list)): + same = (a == b) + ok &= same + tag = "SAME " if same else "DIFF*" + print(f"[{tag}] sample#{i}\n--- merged ---\n{a}\n--- base+peft ---\n{b}\n") + return ok + +def main(): + torch.manual_seed(0) + + tok, merged, peft_model = load_models() + + # ===== (Optional) End-to-end logits/hidden comparison ===== + print("\n=== (0) End-to-end tensors (sanity) ===") + _ = compare_e2e(merged, peft_model, tok) + + # ===== 1) Deterministic verification (greedy) ===== + greedy_args = dict( + do_sample=False, + max_new_tokens=MAX_NEW_TOKENS, + num_beams=1, # ✅ Beam search disabled + repetition_penalty=1.0, # ✅ Keep default value + temperature=None, + top_p=None, top_k=None, + eos_token_id=tok.eos_token_id, + pad_token_id=tok.pad_token_id, + use_cache=True + ) + print("\n=== GREEDY (deterministic) ===") + out_m_g, texts_m_g = run_generate(merged, tok, PROMPTS, GPU_MERGED, **greedy_args) + out_p_g, texts_p_g = run_generate(peft_model, tok, PROMPTS, GPU_PEFT, **greedy_args) + ok_greedy = compare_texts(texts_m_g, texts_p_g) + + # ===== 2) Module-wise effective weight comparison ===== + print("\n=== (2) Module-wise effective weights ===") + ok_w = compare_weights(merged, peft_model) + + # Summary + print("\n=== SUMMARY ===") + print("GREEDY MATCH ✅" if ok_greedy else "GREEDY MISMATCH ❌") + if not ok_greedy: + print("※ Please recheck adapter/key mapping to match from greedy.") + +if __name__ == "__main__": + main() \ No newline at end of file From a4039d05ddc3be0ca3433744310a72869222a2e8 Mon Sep 17 00:00:00 2001 
From: yang-kichang Date: Wed, 22 Oct 2025 18:27:59 +0900 Subject: [PATCH 7/9] applied lint --- swift/megatron/model/gpt/config.py | 2 +- swift/megatron/model/gpt/mcore2hf_lora.py | 225 +++++++++++----------- swift/megatron/model/register.py | 8 +- swift/megatron/utils/convert.py | 14 +- tests/megatron/test_lora_export.py | 113 +++++------ 5 files changed, 184 insertions(+), 178 deletions(-) diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py index e779a827b6..431b285340 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/model/gpt/config.py @@ -68,7 +68,7 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]: res['mrope_interleaved'] = mrope_interleaved if first_k_dense_replace is not None: - res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}' + res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res['num_layers'] - first_k_dense_replace}' if res.get('moe_router_score_function', 'softmax') == 'sigmoid': res['moe_router_enable_expert_bias'] = True if n_shared_experts is not None and 'moe_shared_expert_intermediate_size' not in res: diff --git a/swift/megatron/model/gpt/mcore2hf_lora.py b/swift/megatron/model/gpt/mcore2hf_lora.py index 4a2f6cc941..3acfd556e6 100644 --- a/swift/megatron/model/gpt/mcore2hf_lora.py +++ b/swift/megatron/model/gpt/mcore2hf_lora.py @@ -15,91 +15,93 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir: str, num_query_groups: int) -> None: """ Convert Megatron Core LoRA adapter to HuggingFace PEFT format. - + Args: peft_model: Megatron Core PEFTModel mg_model: loaded Megatron Core Model (for shape) hf_model: HuggingFace model (required for shape extraction) dst_dir: Dir path to saving HuggingFace PEFT - num_query_groups: number of Attention group + num_query_groups: number of Attention group """ os.makedirs(dst_dir, exist_ok=True) - dst_model = os.path.join(dst_dir, "adapter_model.safetensors") - dst_cfg = os.path.join(dst_dir, "adapter_config.json") - + dst_model = os.path.join(dst_dir, 'adapter_model.safetensors') + dst_cfg = os.path.join(dst_dir, 'adapter_config.json') + logger.info(f"Converting Megatron Core LoRA to HF PEFT format at {dst_dir}") - + # Extract shape information from HuggingFace model - logger.info("Extracting shape information from HuggingFace model...") + logger.info('Extracting shape information from HuggingFace model...') attn0 = hf_model.model.layers[0].self_attn - + q_out, in_features = attn0.q_proj.weight.shape # [out, in] k_out, _ = attn0.k_proj.weight.shape v_out, _ = attn0.v_proj.weight.shape - + q_dim = q_out // num_query_groups kv_dim = k_out // num_query_groups - assert v_out // num_query_groups == kv_dim, "k/v group out dim mismatch" - - logger.info(f"Shape extraction: num_query_groups={num_query_groups}, q_dim={q_dim}, kv_dim={kv_dim}, in_features={in_features}") - + assert v_out // num_query_groups == kv_dim, 'k/v group out dim mismatch' + + logger.info( + f"Shape extraction: num_query_groups={num_query_groups}, q_dim={q_dim}, kv_dim={kv_dim}, in_features={in_features}" + ) + # Bucketize modules from peft_model state_dict - logger.info("Extracting LoRA weights from loaded PEFTModel...") + logger.info('Extracting LoRA weights from loaded PEFTModel...') bucket = {} # prefix -> {local_name: tensor} state_dict = peft_model.state_dict() - + for fullkey, tensor in state_dict.items(): # Process only adapter-related keys - if "lora_A" not in fullkey and "lora_B" not in fullkey: + if 'lora_A' not in fullkey 
and 'lora_B' not in fullkey: continue - parts = fullkey.split(".") - + parts = fullkey.split('.') + # Parse key considering .default.weight format - if len(parts) >= 2 and parts[-2] == "default": + if len(parts) >= 2 and parts[-2] == 'default': # e.g., lora_A.default.weight -> lora_A.weight local = f"{parts[-3]}.{parts[-1]}" # default.weight - prefix = ".".join(parts[:-3]) # e.g., ...linear_qkv + prefix = '.'.join(parts[:-3]) # e.g., ...linear_qkv else: # Original logic: e.g., lora_A.weight - local = ".".join(parts[-2:]) # e.g., lora_A.weight - prefix = ".".join(parts[:-2]) # e.g., ...linear_qkv - + local = '.'.join(parts[-2:]) # e.g., lora_A.weight + prefix = '.'.join(parts[:-2]) # e.g., ...linear_qkv + bucket.setdefault(prefix, {})[local] = tensor.cpu() - + dst_tensors = OrderedDict() - + def push(dst, key, tensor): """Create independent copy of tensor for saving + ensure contiguous memory""" t = tensor.detach().clone().contiguous() if key in dst: raise ValueError(f"Duplicate key: {key}") - if "weight" not in key: + if 'weight' not in key: logger.debug(f"Skipping non-weight key: {key}") return key = remap_key_for_peft(key) dst[key] = t - + def remap_key_for_peft(key: str) -> str: """Convert key to HuggingFace PEFT format""" # 1) decoder → model - key = key.replace(".decoder.layers.", ".model.layers.") + key = key.replace('.decoder.layers.', '.model.layers.') # 2) self_attention → self_attn - key = key.replace(".self_attention.", ".self_attn.") + key = key.replace('.self_attention.', '.self_attn.') # 3) check prefix - if key.startswith("model.layers."): - key = "base_model.model." + key + if key.startswith('model.layers.'): + key = 'base_model.model.' + key return key - + def convert_linear_proj(prefix, tensors): """mcore: ...self_attention.linear_proj -> HF: ...self_attn.o_proj""" - new_prefix = prefix.replace(".self_attention.linear_proj", ".self_attn.o_proj") + new_prefix = prefix.replace('.self_attention.linear_proj', '.self_attn.o_proj') for local, T in tensors.items(): push(dst_tensors, f"{new_prefix}.{local}", T) - + def convert_linear_qkv(prefix, tensors): """ Split Megatron Core fused qkv LoRA into HF q_proj, k_proj, v_proj - + mcore: A: [r, in_features] (shared) B: [num_query_groups*(q_dim+kv_dim+kv_dim), r] @@ -108,53 +110,53 @@ def convert_linear_qkv(prefix, tensors): k_proj: A=[r,in], B=[num_query_groups*kv_dim, r] v_proj: A=[r,in], B=[num_query_groups*kv_dim, r] """ - A = tensors.get("lora_A.weight", None) - B = tensors.get("lora_B.weight", None) + A = tensors.get('lora_A.weight', None) + B = tensors.get('lora_B.weight', None) if A is None or B is None: # If core weights are missing, pass through with original key for local, T in tensors.items(): push(dst_tensors, f"{prefix}.{local}", T) return - + r, in_A = A.shape out_B, rB = B.shape assert rB == r, f"LoRA rank mismatch: A={r}, B={rB}" assert in_A == in_features, f"in_features mismatch: A={in_A}, base={in_features}" - + expected_out = num_query_groups * (q_dim + kv_dim + kv_dim) assert out_B == expected_out, f"Fused B out({out_B}) != expected({expected_out})" - + # Reshape to [num_query_groups, (q_dim+kv_dim+kv_dim), r] then slice Bg = B.reshape(num_query_groups, q_dim + kv_dim + kv_dim, r) Bq = Bg[:, :q_dim, :].reshape(num_query_groups * q_dim, r) - Bk = Bg[:, q_dim:q_dim+kv_dim, :].reshape(num_query_groups * kv_dim, r) - Bv = Bg[:, q_dim+kv_dim:, :].reshape(num_query_groups * kv_dim, r) - - misc = {k: v for k, v in tensors.items() if k not in ("lora_A.weight", "lora_B.weight")} - + Bk = Bg[:, q_dim:q_dim + kv_dim, 
:].reshape(num_query_groups * kv_dim, r) + Bv = Bg[:, q_dim + kv_dim:, :].reshape(num_query_groups * kv_dim, r) + + misc = {k: v for k, v in tensors.items() if k not in ('lora_A.weight', 'lora_B.weight')} + # q_proj - q_prefix = prefix.replace(".self_attention.linear_qkv", ".self_attn.q_proj") + q_prefix = prefix.replace('.self_attention.linear_qkv', '.self_attn.q_proj') push(dst_tensors, f"{q_prefix}.lora_A.weight", A) push(dst_tensors, f"{q_prefix}.lora_B.weight", Bq) - + # k_proj - k_prefix = prefix.replace(".self_attention.linear_qkv", ".self_attn.k_proj") + k_prefix = prefix.replace('.self_attention.linear_qkv', '.self_attn.k_proj') push(dst_tensors, f"{k_prefix}.lora_A.weight", A) push(dst_tensors, f"{k_prefix}.lora_B.weight", Bk) for k, v in misc.items(): push(dst_tensors, f"{k_prefix}.{k}", v) - + # v_proj - v_prefix = prefix.replace(".self_attention.linear_qkv", ".self_attn.v_proj") + v_prefix = prefix.replace('.self_attention.linear_qkv', '.self_attn.v_proj') push(dst_tensors, f"{v_prefix}.lora_A.weight", A) push(dst_tensors, f"{v_prefix}.lora_B.weight", Bv) for k, v in misc.items(): push(dst_tensors, f"{v_prefix}.{k}", v) - + def convert_mla_attention(prefix, tensors): """ Multi-Latent Attention (MLA) LoRA conversion - + mcore -> HF: linear_q_down_proj -> q_a_proj linear_q_up_proj -> q_b_proj @@ -162,137 +164,138 @@ def convert_mla_attention(prefix, tensors): linear_kv_up_proj -> kv_b_proj """ # q_proj (down -> a, up -> b) - if ".linear_q_down_proj" in prefix: - new_prefix = prefix.replace(".linear_q_down_proj", ".q_a_proj") + if '.linear_q_down_proj' in prefix: + new_prefix = prefix.replace('.linear_q_down_proj', '.q_a_proj') for local, T in tensors.items(): push(dst_tensors, f"{new_prefix}.{local}", T) - elif ".linear_q_up_proj" in prefix: - new_prefix = prefix.replace(".linear_q_up_proj", ".q_b_proj") + elif '.linear_q_up_proj' in prefix: + new_prefix = prefix.replace('.linear_q_up_proj', '.q_b_proj') for local, T in tensors.items(): push(dst_tensors, f"{new_prefix}.{local}", T) - elif ".linear_kv_down_proj" in prefix: - new_prefix = prefix.replace(".linear_kv_down_proj", ".kv_a_proj_with_mqa") + elif '.linear_kv_down_proj' in prefix: + new_prefix = prefix.replace('.linear_kv_down_proj', '.kv_a_proj_with_mqa') for local, T in tensors.items(): push(dst_tensors, f"{new_prefix}.{local}", T) - elif ".linear_kv_up_proj" in prefix: - new_prefix = prefix.replace(".linear_kv_up_proj", ".kv_b_proj") + elif '.linear_kv_up_proj' in prefix: + new_prefix = prefix.replace('.linear_kv_up_proj', '.kv_b_proj') for local, T in tensors.items(): push(dst_tensors, f"{new_prefix}.{local}", T) - + def convert_mlp_linear_fc1(prefix, tensors): """ Split MLP linear_fc1 LoRA into HF gate_proj, up_proj - + mcore: linear_fc1 [gate_up_dim, in_features] -> HF: gate_proj [gate_dim, in_features], up_proj [up_dim, in_features] """ - A = tensors.get("lora_A.weight", None) - B = tensors.get("lora_B.weight", None) + A = tensors.get('lora_A.weight', None) + B = tensors.get('lora_B.weight', None) if A is None or B is None: for local, T in tensors.items(): push(dst_tensors, f"{prefix}.{local}", T) return - + # Split gate_up_dim into gate_dim and up_dim (usually 1:1 ratio) gate_up_dim = B.shape[0] gate_dim = gate_up_dim // 2 up_dim = gate_up_dim - gate_dim - + # Split B into gate and up B_gate = B[:gate_dim, :] B_up = B[gate_dim:, :] - - misc = {k: v for k, v in tensors.items() if k not in ("lora_A.weight", "lora_B.weight")} - + + misc = {k: v for k, v in tensors.items() if k not in ('lora_A.weight', 
'lora_B.weight')} + # gate_proj - gate_prefix = prefix.replace(".mlp.linear_fc1", ".mlp.gate_proj") + gate_prefix = prefix.replace('.mlp.linear_fc1', '.mlp.gate_proj') push(dst_tensors, f"{gate_prefix}.lora_A.weight", A) push(dst_tensors, f"{gate_prefix}.lora_B.weight", B_gate) for k, v in misc.items(): push(dst_tensors, f"{gate_prefix}.{k}", v) - + # up_proj - up_prefix = prefix.replace(".mlp.linear_fc1", ".mlp.up_proj") + up_prefix = prefix.replace('.mlp.linear_fc1', '.mlp.up_proj') push(dst_tensors, f"{up_prefix}.lora_A.weight", A) push(dst_tensors, f"{up_prefix}.lora_B.weight", B_up) for k, v in misc.items(): push(dst_tensors, f"{up_prefix}.{k}", v) - + def convert_mlp_linear_fc2(prefix, tensors): """Convert MLP linear_fc2 LoRA to HF down_proj""" - new_prefix = prefix.replace(".mlp.linear_fc2", ".mlp.down_proj") + new_prefix = prefix.replace('.mlp.linear_fc2', '.mlp.down_proj') for local, T in tensors.items(): push(dst_tensors, f"{new_prefix}.{local}", T) - + def convert_moe_experts(prefix, tensors): """MoE experts LoRA conversion""" # experts[expert_idx].linear_fc1 -> experts[expert_idx].gate_proj, up_proj - if ".linear_fc1" in prefix: + if '.linear_fc1' in prefix: convert_mlp_linear_fc1(prefix, tensors) # experts[expert_idx].linear_fc2 -> experts[expert_idx].down_proj - elif ".linear_fc2" in prefix: + elif '.linear_fc2' in prefix: convert_mlp_linear_fc2(prefix, tensors) - + # Execute conversion by module for prefix, tensors in bucket.items(): # Attention conversion - if ".self_attention.linear_proj" in prefix: + if '.self_attention.linear_proj' in prefix: convert_linear_proj(prefix, tensors) - elif ".self_attention.linear_qkv" in prefix: + elif '.self_attention.linear_qkv' in prefix: convert_linear_qkv(prefix, tensors) # Multi-Latent Attention conversion - elif any(x in prefix for x in [".linear_q_down_proj", ".linear_q_up_proj", - ".linear_kv_down_proj", ".linear_kv_up_proj"]): + elif any(x in prefix + for x in ['.linear_q_down_proj', '.linear_q_up_proj', '.linear_kv_down_proj', '.linear_kv_up_proj']): convert_mla_attention(prefix, tensors) # MLP conversion - elif ".mlp.linear_fc1" in prefix: + elif '.mlp.linear_fc1' in prefix: convert_mlp_linear_fc1(prefix, tensors) - elif ".mlp.linear_fc2" in prefix: + elif '.mlp.linear_fc2' in prefix: convert_mlp_linear_fc2(prefix, tensors) # MoE experts conversion (excluding router) - elif ".experts" in prefix and (".linear_fc1" in prefix or ".linear_fc2" in prefix): + elif '.experts' in prefix and ('.linear_fc1' in prefix or '.linear_fc2' in prefix): convert_moe_experts(prefix, tensors) else: # Copy unknown modules as-is logger.warning(f"Unknown module pattern: {prefix}") for local, T in tensors.items(): push(dst_tensors, f"{prefix}.{local}", T) - + # Save converted tensors - save_file(dst_tensors, dst_model, metadata={"format": "pt"}) + save_file(dst_tensors, dst_model, metadata={'format': 'pt'}) logger.info(f"Saved converted LoRA tensors to {dst_model}") - + # Update adapter_config.json - logger.info("Converting adapter config...") - cfg = peft_model.peft_config['default'] if isinstance(peft_model.peft_config['default'], dict) else asdict(peft_model.peft_config['default']) - - tm = cfg.get("target_modules", None) + logger.info('Converting adapter config...') + cfg = peft_model.peft_config['default'] if isinstance(peft_model.peft_config['default'], dict) else asdict( + peft_model.peft_config['default']) + + tm = cfg.get('target_modules', None) if tm is not None: if isinstance(tm, str): tm = [tm] new_tm = [] for t in tm: - if t == 
"linear_proj": - new_tm.append("o_proj") - elif t in ("linear_qkv", "query_key_value"): - new_tm.extend(["q_proj", "k_proj", "v_proj"]) - elif t == "linear_fc1": - new_tm.extend(["gate_proj", "up_proj"]) - elif t == "linear_fc2": - new_tm.append("down_proj") - elif t == "linear_q_down_proj": - new_tm.append("q_a_proj") - elif t == "linear_q_up_proj": - new_tm.append("q_b_proj") - elif t == "linear_kv_down_proj": - new_tm.append("kv_a_proj_with_mqa") - elif t == "linear_kv_up_proj": - new_tm.append("kv_b_proj") + if t == 'linear_proj': + new_tm.append('o_proj') + elif t in ('linear_qkv', 'query_key_value'): + new_tm.extend(['q_proj', 'k_proj', 'v_proj']) + elif t == 'linear_fc1': + new_tm.extend(['gate_proj', 'up_proj']) + elif t == 'linear_fc2': + new_tm.append('down_proj') + elif t == 'linear_q_down_proj': + new_tm.append('q_a_proj') + elif t == 'linear_q_up_proj': + new_tm.append('q_b_proj') + elif t == 'linear_kv_down_proj': + new_tm.append('kv_a_proj_with_mqa') + elif t == 'linear_kv_up_proj': + new_tm.append('kv_b_proj') else: new_tm.append(t) - cfg["target_modules"] = sorted(set(new_tm)) - - with open(dst_cfg, "w", encoding="utf-8") as f: + cfg['target_modules'] = sorted(set(new_tm)) + + with open(dst_cfg, 'w', encoding='utf-8') as f: json.dump(cfg, f, ensure_ascii=False, indent=2, default=str) logger.info(f"cfg: {cfg}") diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 6c1ddb3f7b..c0412fa21c 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -30,10 +30,12 @@ class MegatronModelMeta: def convert_mcore_lora_to_hf_peft(self, peft_model, mg_model, hf_model, dst_dir: str, num_groups: int) -> None: """Convert Megatron Core LoRA adapter to HuggingFace PEFT format.""" - # only for gpt model type + # only for gpt model type if self.megatron_model_type != 'gpt': - raise ValueError(f"convert_mcore_lora_to_hf_peft is only supported for gpt model type, but got {self.megatron_model_type}") - + raise ValueError( + f"convert_mcore_lora_to_hf_peft is only supported for gpt model type, but got {self.megatron_model_type}" + ) + from .gpt.mcore2hf import convert_mcore_lora_to_hf_peft convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir, num_groups) diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index fe2b98abd1..33e32b1779 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -305,25 +305,25 @@ def convert_mcore2hf(args: ExportArguments) -> None: load_checkpoint([mg_model], None, None, strict=True) if args.to_hf: hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) - + if megatron_args.adapter_load is not None: peft_model = prepare_mcore_model(mg_model) with adapter_state_dict_context(): load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) if args.to_hf and not args.merge_lora: logger.info(f"Saving LoRA adapter to `{args.output_dir}` ...") - assert megatron_args.multi_latent_attention is False, "Multi-latent attention is not supported for LoRA conversion." - - peft_model.config = asdict(peft_model.config) # for PEFT <= 0.17.1 + assert megatron_args.multi_latent_attention is False, 'Multi-latent attention is not supported for LoRA conversion.' 
+ + peft_model.config = asdict(peft_model.config) # for PEFT <= 0.17.1 # Convert Megatron Core LoRA to HuggingFace PEFT format megatron_model_meta.convert_mcore_lora_to_hf_peft( peft_model=peft_model, mg_model=mg_model, hf_model=hf_model, dst_dir=args.output_dir, - num_groups=megatron_args.num_query_groups if megatron_args.group_query_attention else megatron_args.num_attention_heads - ) - logger.info("LoRA adapter saved successfully.") + num_groups=megatron_args.num_query_groups + if megatron_args.group_query_attention else megatron_args.num_attention_heads) + logger.info('LoRA adapter saved successfully.') return else: logger.info('Merge LoRA...') diff --git a/tests/megatron/test_lora_export.py b/tests/megatron/test_lora_export.py index 73ff443c94..b899c794e6 100644 --- a/tests/megatron/test_lora_export.py +++ b/tests/megatron/test_lora_export.py @@ -11,14 +11,14 @@ # ------------------------------ # Configuration # ------------------------------ -MERGED_MODEL = "/path/to/merged/model" -BASE_MODEL = "/path/to/base/model" -ADAPTER_DIR = "/path/to/adapter/directory" -TOKENIZER = BASE_MODEL +MERGED_MODEL = '/path/to/merged/model' +BASE_MODEL = '/path/to/base/model' +ADAPTER_DIR = '/path/to/adapter/directory' +TOKENIZER = BASE_MODEL # GPU assignment (0-based) GPU_MERGED = 0 -GPU_PEFT = 1 +GPU_PEFT = 1 # Tolerance ATOL = 1e-5 @@ -29,7 +29,7 @@ PAD_AS_EOS = True PROMPTS = [ - "User: Please explain the definition of CPU.\nAssistant: ", + 'User: Please explain the definition of CPU.\nAssistant: ', "Translate the following sentence to English: 'The wind is so refreshing today.'", ] @@ -38,48 +38,48 @@ # Utilities # ----------------------------- def pin_to_single_gpu(model_name: str, gpu_index: int, dtype=torch.bfloat16): - device_map = {"": gpu_index} - model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=dtype, - device_map=device_map - ).eval() + device_map = {'': gpu_index} + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device_map).eval() return model + def attach_peft_on_gpu(base_model_name: str, adapter_dir: str, gpu_index: int, dtype=torch.bfloat16): - device_map = {"": gpu_index} - base = AutoModelForCausalLM.from_pretrained( - base_model_name, - torch_dtype=dtype, - device_map=device_map - ).eval() + device_map = {'': gpu_index} + base = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, device_map=device_map).eval() model = PeftModel.from_pretrained(base, adapter_dir).eval() return model + def make_inputs(tokenizer: AutoTokenizer, seq_len: int = 32, batch_size: int = 2): texts = [f"Verification sample #{i}." 
for i in range(batch_size)] - enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=seq_len) + enc = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=seq_len) return enc + @dataclass class DiffStat: max_abs: float mean_abs: float cos_sim: float + def tensor_diff(a: torch.Tensor, b: torch.Tensor) -> DiffStat: a = a.detach().float().cpu().reshape(-1) b = b.detach().float().cpu().reshape(-1) max_abs = (a - b).abs().max().item() mean_abs = (a - b).abs().mean().item() - cos = float("nan") if a.norm() == 0 or b.norm() == 0 else torch.nn.functional.cosine_similarity(a, b, dim=0).item() + cos = float('nan') if a.norm() == 0 or b.norm() == 0 else torch.nn.functional.cosine_similarity(a, b, dim=0).item() return DiffStat(max_abs, mean_abs, cos) + def report(stat: DiffStat, tag: str): ok = stat.max_abs <= ATOL + RTOL * max(1.0, stat.mean_abs) - print(f"[{tag}] max|Δ|={stat.max_abs:.3e} mean|Δ|={stat.mean_abs:.3e} cos={stat.cos_sim:.6f} -> {'OK' if ok else 'MISMATCH'}") + print( + f"[{tag}] max|Δ|={stat.max_abs:.3e} mean|Δ|={stat.mean_abs:.3e} cos={stat.cos_sim:.6f} -> {'OK' if ok else 'MISMATCH'}" + ) return ok + # ----------------------------- # 1) End-to-end comparison # ----------------------------- @@ -87,28 +87,31 @@ def report(stat: DiffStat, tag: str): def compare_e2e(merged, peft_model, tokenizer): batch_cpu = make_inputs(tokenizer) batch_m = {k: v.to(f"cuda:{GPU_MERGED}") for k, v in batch_cpu.items()} - batch_p = {k: v.to(f"cuda:{GPU_PEFT}") for k, v in batch_cpu.items()} + batch_p = {k: v.to(f"cuda:{GPU_PEFT}") for k, v in batch_cpu.items()} out_m = merged(**batch_m, output_hidden_states=True) out_p = peft_model(**batch_p, output_hidden_states=True) - ok1 = report(tensor_diff(out_m.logits, out_p.logits), "logits") - ok2 = report(tensor_diff(out_m.hidden_states[-1], out_p.hidden_states[-1]), "hidden_states[-1]") + ok1 = report(tensor_diff(out_m.logits, out_p.logits), 'logits') + ok2 = report(tensor_diff(out_m.hidden_states[-1], out_p.hidden_states[-1]), 'hidden_states[-1]') return ok1 and ok2 + # ----------------------------- # 2) Module-wise effective weight comparison # (W_eff = W0 + Σ(B@A)*scale) # ----------------------------- -def find_linear_modules(model: nn.Module, suffixes=("q_proj","k_proj","v_proj","o_proj")) -> Dict[str, nn.Linear]: +def find_linear_modules(model: nn.Module, suffixes=('q_proj', 'k_proj', 'v_proj', 'o_proj')) -> Dict[str, nn.Linear]: out = {} for name, mod in model.named_modules(): if isinstance(mod, nn.Linear) and any(name.endswith(f".self_attn.{suf}") for suf in suffixes): out[name] = mod return out + def peft_has_lora(mod: nn.Module) -> bool: - return hasattr(mod, "lora_A") and hasattr(mod, "lora_B") + return hasattr(mod, 'lora_A') and hasattr(mod, 'lora_B') + def peft_effective_weight(linear_mod: nn.Module) -> torch.Tensor: W0 = linear_mod.weight.detach().float().cpu() @@ -116,7 +119,7 @@ def peft_effective_weight(linear_mod: nn.Module) -> torch.Tensor: return W0 delta = torch.zeros_like(W0) - fan_in_fan_out = getattr(linear_mod, "fan_in_fan_out", False) + fan_in_fan_out = getattr(linear_mod, 'fan_in_fan_out', False) for name, A in linear_mod.lora_A.items(): if name not in linear_mod.lora_B: @@ -129,17 +132,18 @@ def peft_effective_weight(linear_mod: nn.Module) -> torch.Tensor: BA = (A_w.t() @ B_w.t()).t() if fan_in_fan_out else (B_w @ A_w) # scaling - if hasattr(linear_mod, "scaling"): + if hasattr(linear_mod, 'scaling'): scale = float(linear_mod.scaling[name]) else: r = A_w.shape[0] - alpha = 
getattr(linear_mod, "lora_alpha", r) + alpha = getattr(linear_mod, 'lora_alpha', r) scale = float(alpha) / float(r) delta += BA * scale return W0 + delta + def _resolve_in_peft(peft_model: nn.Module, merged_name: str) -> Optional[nn.Module]: """ Based on the merged module name, sequentially try possible prefixes in the PEFT wrapper. @@ -156,6 +160,7 @@ def _resolve_in_peft(peft_model: nn.Module, merged_name: str) -> Optional[nn.Mod return peft_named[cand] return None + @torch.inference_mode() def compare_weights(merged, peft_model): ok_all = True @@ -176,6 +181,7 @@ def compare_weights(merged, peft_model): return ok_all + # ----------------------------- # Generation comparison # ----------------------------- @@ -183,78 +189,73 @@ def load_models(): tok = AutoTokenizer.from_pretrained(TOKENIZER, use_fast=True) if PAD_AS_EOS and tok.pad_token_id is None: tok.pad_token = tok.eos_token - tok.padding_side = "left" # Improved batch padding stability in causal LM + tok.padding_side = 'left' # Improved batch padding stability in causal LM - merged = AutoModelForCausalLM.from_pretrained( - MERGED_MODEL, torch_dtype=DTYPE, device_map={"": GPU_MERGED} - ).eval() + merged = AutoModelForCausalLM.from_pretrained(MERGED_MODEL, torch_dtype=DTYPE, device_map={'': GPU_MERGED}).eval() - base = AutoModelForCausalLM.from_pretrained( - BASE_MODEL, torch_dtype=DTYPE, device_map={"": GPU_PEFT} - ).eval() + base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=DTYPE, device_map={'': GPU_PEFT}).eval() peft_model = PeftModel.from_pretrained(base, ADAPTER_DIR).eval() model = peft_model.merge_and_unload().eval() return tok, merged, model + @torch.inference_mode() def run_generate(model, tok, prompts, device, **gen_kwargs): - enc = tok(prompts, return_tensors="pt", padding=True) + enc = tok(prompts, return_tensors='pt', padding=True) enc = {k: v.to(f"cuda:{device}") for k, v in enc.items()} - out = model.generate( - **enc, - **gen_kwargs, - return_dict_in_generate=True, - output_scores=True - ) + out = model.generate(**enc, **gen_kwargs, return_dict_in_generate=True, output_scores=True) texts = tok.batch_decode(out.sequences, skip_special_tokens=True) return out, texts + def compare_texts(a_list, b_list): ok = True for i, (a, b) in enumerate(zip(a_list, b_list)): same = (a == b) ok &= same - tag = "SAME " if same else "DIFF*" + tag = 'SAME ' if same else 'DIFF*' print(f"[{tag}] sample#{i}\n--- merged ---\n{a}\n--- base+peft ---\n{b}\n") return ok + def main(): torch.manual_seed(0) tok, merged, peft_model = load_models() # ===== (Optional) End-to-end logits/hidden comparison ===== - print("\n=== (0) End-to-end tensors (sanity) ===") + print('\n=== (0) End-to-end tensors (sanity) ===') _ = compare_e2e(merged, peft_model, tok) # ===== 1) Deterministic verification (greedy) ===== greedy_args = dict( do_sample=False, max_new_tokens=MAX_NEW_TOKENS, - num_beams=1, # ✅ Beam search disabled + num_beams=1, # ✅ Beam search disabled repetition_penalty=1.0, # ✅ Keep default value temperature=None, - top_p=None, top_k=None, + top_p=None, + top_k=None, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id, - use_cache=True - ) - print("\n=== GREEDY (deterministic) ===") + use_cache=True) + print('\n=== GREEDY (deterministic) ===') out_m_g, texts_m_g = run_generate(merged, tok, PROMPTS, GPU_MERGED, **greedy_args) out_p_g, texts_p_g = run_generate(peft_model, tok, PROMPTS, GPU_PEFT, **greedy_args) ok_greedy = compare_texts(texts_m_g, texts_p_g) # ===== 2) Module-wise effective weight comparison ===== - print("\n=== 
(2) Module-wise effective weights ===") + print('\n=== (2) Module-wise effective weights ===') ok_w = compare_weights(merged, peft_model) # Summary - print("\n=== SUMMARY ===") - print("GREEDY MATCH ✅" if ok_greedy else "GREEDY MISMATCH ❌") + print('\n=== SUMMARY ===') + print('GREEDY MATCH ✅' if ok_greedy else 'GREEDY MISMATCH ❌') if not ok_greedy: - print("※ Please recheck adapter/key mapping to match from greedy.") + print('※ Please recheck adapter/key mapping to match from greedy.') + -if __name__ == "__main__": - main() \ No newline at end of file +if __name__ == '__main__': + main() From e2c96058e1508447f3f0bdfa60f9f08884e05df6 Mon Sep 17 00:00:00 2001 From: yang-kichang Date: Wed, 22 Oct 2025 19:46:02 +0900 Subject: [PATCH 8/9] apply lint --- swift/megatron/model/gpt/mcore2hf_lora.py | 73 ++++++++++++----------- swift/megatron/model/register.py | 2 +- swift/megatron/utils/convert.py | 4 +- tests/megatron/test_lora_export.py | 32 +++++----- 4 files changed, 56 insertions(+), 55 deletions(-) diff --git a/swift/megatron/model/gpt/mcore2hf_lora.py b/swift/megatron/model/gpt/mcore2hf_lora.py index 3acfd556e6..59fd521867 100644 --- a/swift/megatron/model/gpt/mcore2hf_lora.py +++ b/swift/megatron/model/gpt/mcore2hf_lora.py @@ -1,12 +1,13 @@ # Copyright (c) Kakao Corp. (AI Alignment Team). # Contact: kevin.us@kakaocorp.com -import json import os from collections import OrderedDict from dataclasses import asdict +import json from safetensors.torch import save_file + from swift.utils import get_logger logger = get_logger() @@ -27,7 +28,7 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir: str, dst_model = os.path.join(dst_dir, 'adapter_model.safetensors') dst_cfg = os.path.join(dst_dir, 'adapter_config.json') - logger.info(f"Converting Megatron Core LoRA to HF PEFT format at {dst_dir}") + logger.info(f'Converting Megatron Core LoRA to HF PEFT format at {dst_dir}') # Extract shape information from HuggingFace model logger.info('Extracting shape information from HuggingFace model...') @@ -42,7 +43,7 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir: str, assert v_out // num_query_groups == kv_dim, 'k/v group out dim mismatch' logger.info( - f"Shape extraction: num_query_groups={num_query_groups}, q_dim={q_dim}, kv_dim={kv_dim}, in_features={in_features}" + f'Shape extraction: num_query_groups={num_query_groups}, q_dim={q_dim}, kv_dim={kv_dim}, in_features={in_features}' ) # Bucketize modules from peft_model state_dict @@ -59,7 +60,7 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir: str, # Parse key considering .default.weight format if len(parts) >= 2 and parts[-2] == 'default': # e.g., lora_A.default.weight -> lora_A.weight - local = f"{parts[-3]}.{parts[-1]}" # default.weight + local = f'{parts[-3]}.{parts[-1]}' # default.weight prefix = '.'.join(parts[:-3]) # e.g., ...linear_qkv else: # Original logic: e.g., lora_A.weight @@ -74,9 +75,9 @@ def push(dst, key, tensor): """Create independent copy of tensor for saving + ensure contiguous memory""" t = tensor.detach().clone().contiguous() if key in dst: - raise ValueError(f"Duplicate key: {key}") + raise ValueError(f'Duplicate key: {key}') if 'weight' not in key: - logger.debug(f"Skipping non-weight key: {key}") + logger.debug(f'Skipping non-weight key: {key}') return key = remap_key_for_peft(key) dst[key] = t @@ -96,7 +97,7 @@ def convert_linear_proj(prefix, tensors): """mcore: ...self_attention.linear_proj -> HF: ...self_attn.o_proj""" new_prefix = 
prefix.replace('.self_attention.linear_proj', '.self_attn.o_proj') for local, T in tensors.items(): - push(dst_tensors, f"{new_prefix}.{local}", T) + push(dst_tensors, f'{new_prefix}.{local}', T) def convert_linear_qkv(prefix, tensors): """ @@ -115,16 +116,16 @@ def convert_linear_qkv(prefix, tensors): if A is None or B is None: # If core weights are missing, pass through with original key for local, T in tensors.items(): - push(dst_tensors, f"{prefix}.{local}", T) + push(dst_tensors, f'{prefix}.{local}', T) return r, in_A = A.shape out_B, rB = B.shape - assert rB == r, f"LoRA rank mismatch: A={r}, B={rB}" - assert in_A == in_features, f"in_features mismatch: A={in_A}, base={in_features}" + assert rB == r, f'LoRA rank mismatch: A={r}, B={rB}' + assert in_A == in_features, f'in_features mismatch: A={in_A}, base={in_features}' expected_out = num_query_groups * (q_dim + kv_dim + kv_dim) - assert out_B == expected_out, f"Fused B out({out_B}) != expected({expected_out})" + assert out_B == expected_out, f'Fused B out({out_B}) != expected({expected_out})' # Reshape to [num_query_groups, (q_dim+kv_dim+kv_dim), r] then slice Bg = B.reshape(num_query_groups, q_dim + kv_dim + kv_dim, r) @@ -136,22 +137,22 @@ def convert_linear_qkv(prefix, tensors): # q_proj q_prefix = prefix.replace('.self_attention.linear_qkv', '.self_attn.q_proj') - push(dst_tensors, f"{q_prefix}.lora_A.weight", A) - push(dst_tensors, f"{q_prefix}.lora_B.weight", Bq) + push(dst_tensors, f'{q_prefix}.lora_A.weight', A) + push(dst_tensors, f'{q_prefix}.lora_B.weight', Bq) # k_proj k_prefix = prefix.replace('.self_attention.linear_qkv', '.self_attn.k_proj') - push(dst_tensors, f"{k_prefix}.lora_A.weight", A) - push(dst_tensors, f"{k_prefix}.lora_B.weight", Bk) + push(dst_tensors, f'{k_prefix}.lora_A.weight', A) + push(dst_tensors, f'{k_prefix}.lora_B.weight', Bk) for k, v in misc.items(): - push(dst_tensors, f"{k_prefix}.{k}", v) + push(dst_tensors, f'{k_prefix}.{k}', v) # v_proj v_prefix = prefix.replace('.self_attention.linear_qkv', '.self_attn.v_proj') - push(dst_tensors, f"{v_prefix}.lora_A.weight", A) - push(dst_tensors, f"{v_prefix}.lora_B.weight", Bv) + push(dst_tensors, f'{v_prefix}.lora_A.weight', A) + push(dst_tensors, f'{v_prefix}.lora_B.weight', Bv) for k, v in misc.items(): - push(dst_tensors, f"{v_prefix}.{k}", v) + push(dst_tensors, f'{v_prefix}.{k}', v) def convert_mla_attention(prefix, tensors): """ @@ -167,19 +168,19 @@ def convert_mla_attention(prefix, tensors): if '.linear_q_down_proj' in prefix: new_prefix = prefix.replace('.linear_q_down_proj', '.q_a_proj') for local, T in tensors.items(): - push(dst_tensors, f"{new_prefix}.{local}", T) + push(dst_tensors, f'{new_prefix}.{local}', T) elif '.linear_q_up_proj' in prefix: new_prefix = prefix.replace('.linear_q_up_proj', '.q_b_proj') for local, T in tensors.items(): - push(dst_tensors, f"{new_prefix}.{local}", T) + push(dst_tensors, f'{new_prefix}.{local}', T) elif '.linear_kv_down_proj' in prefix: new_prefix = prefix.replace('.linear_kv_down_proj', '.kv_a_proj_with_mqa') for local, T in tensors.items(): - push(dst_tensors, f"{new_prefix}.{local}", T) + push(dst_tensors, f'{new_prefix}.{local}', T) elif '.linear_kv_up_proj' in prefix: new_prefix = prefix.replace('.linear_kv_up_proj', '.kv_b_proj') for local, T in tensors.items(): - push(dst_tensors, f"{new_prefix}.{local}", T) + push(dst_tensors, f'{new_prefix}.{local}', T) def convert_mlp_linear_fc1(prefix, tensors): """ @@ -192,7 +193,7 @@ def convert_mlp_linear_fc1(prefix, tensors): B = 
tensors.get('lora_B.weight', None) if A is None or B is None: for local, T in tensors.items(): - push(dst_tensors, f"{prefix}.{local}", T) + push(dst_tensors, f'{prefix}.{local}', T) return # Split gate_up_dim into gate_dim and up_dim (usually 1:1 ratio) @@ -208,23 +209,23 @@ def convert_mlp_linear_fc1(prefix, tensors): # gate_proj gate_prefix = prefix.replace('.mlp.linear_fc1', '.mlp.gate_proj') - push(dst_tensors, f"{gate_prefix}.lora_A.weight", A) - push(dst_tensors, f"{gate_prefix}.lora_B.weight", B_gate) + push(dst_tensors, f'{gate_prefix}.lora_A.weight', A) + push(dst_tensors, f'{gate_prefix}.lora_B.weight', B_gate) for k, v in misc.items(): - push(dst_tensors, f"{gate_prefix}.{k}", v) + push(dst_tensors, f'{gate_prefix}.{k}', v) # up_proj up_prefix = prefix.replace('.mlp.linear_fc1', '.mlp.up_proj') - push(dst_tensors, f"{up_prefix}.lora_A.weight", A) - push(dst_tensors, f"{up_prefix}.lora_B.weight", B_up) + push(dst_tensors, f'{up_prefix}.lora_A.weight', A) + push(dst_tensors, f'{up_prefix}.lora_B.weight', B_up) for k, v in misc.items(): - push(dst_tensors, f"{up_prefix}.{k}", v) + push(dst_tensors, f'{up_prefix}.{k}', v) def convert_mlp_linear_fc2(prefix, tensors): """Convert MLP linear_fc2 LoRA to HF down_proj""" new_prefix = prefix.replace('.mlp.linear_fc2', '.mlp.down_proj') for local, T in tensors.items(): - push(dst_tensors, f"{new_prefix}.{local}", T) + push(dst_tensors, f'{new_prefix}.{local}', T) def convert_moe_experts(prefix, tensors): """MoE experts LoRA conversion""" @@ -256,13 +257,13 @@ def convert_moe_experts(prefix, tensors): convert_moe_experts(prefix, tensors) else: # Copy unknown modules as-is - logger.warning(f"Unknown module pattern: {prefix}") + logger.warning(f'Unknown module pattern: {prefix}') for local, T in tensors.items(): - push(dst_tensors, f"{prefix}.{local}", T) + push(dst_tensors, f'{prefix}.{local}', T) # Save converted tensors save_file(dst_tensors, dst_model, metadata={'format': 'pt'}) - logger.info(f"Saved converted LoRA tensors to {dst_model}") + logger.info(f'Saved converted LoRA tensors to {dst_model}') # Update adapter_config.json logger.info('Converting adapter config...') @@ -298,5 +299,5 @@ def convert_moe_experts(prefix, tensors): with open(dst_cfg, 'w', encoding='utf-8') as f: json.dump(cfg, f, ensure_ascii=False, indent=2, default=str) - logger.info(f"cfg: {cfg}") - logger.info(f"Saved converted adapter config to {dst_cfg}") + logger.info(f'cfg: {cfg}') + logger.info(f'Saved converted adapter config to {dst_cfg}') diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index c0412fa21c..cb6c73f3ec 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -33,7 +33,7 @@ def convert_mcore_lora_to_hf_peft(self, peft_model, mg_model, hf_model, dst_dir: # only for gpt model type if self.megatron_model_type != 'gpt': raise ValueError( - f"convert_mcore_lora_to_hf_peft is only supported for gpt model type, but got {self.megatron_model_type}" + f'convert_mcore_lora_to_hf_peft is only supported for gpt model type, but got {self.megatron_model_type}' ) from .gpt.mcore2hf import convert_mcore_lora_to_hf_peft diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index 33e32b1779..dae7a25367 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -2,7 +2,7 @@ import math from contextlib import contextmanager -from dataclasses import fields, asdict +from dataclasses import asdict, fields from typing import Any, Dict import torch @@ -311,7 
+311,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: with adapter_state_dict_context(): load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) if args.to_hf and not args.merge_lora: - logger.info(f"Saving LoRA adapter to `{args.output_dir}` ...") + logger.info(f'Saving LoRA adapter to `{args.output_dir}` ...') assert megatron_args.multi_latent_attention is False, 'Multi-latent attention is not supported for LoRA conversion.' peft_model.config = asdict(peft_model.config) # for PEFT <= 0.17.1 diff --git a/tests/megatron/test_lora_export.py b/tests/megatron/test_lora_export.py index b899c794e6..0d1781544d 100644 --- a/tests/megatron/test_lora_export.py +++ b/tests/megatron/test_lora_export.py @@ -1,12 +1,12 @@ -import os import math -import torch -import torch.nn as nn +import os from dataclasses import dataclass -from typing import Dict, Tuple, List, Optional +from typing import Dict, List, Optional, Tuple -from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +import torch.nn as nn from peft import PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer # ------------------------------ # Configuration @@ -51,7 +51,7 @@ def attach_peft_on_gpu(base_model_name: str, adapter_dir: str, gpu_index: int, d def make_inputs(tokenizer: AutoTokenizer, seq_len: int = 32, batch_size: int = 2): - texts = [f"Verification sample #{i}." for i in range(batch_size)] + texts = [f'Verification sample #{i}.' for i in range(batch_size)] enc = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=seq_len) return enc @@ -86,8 +86,8 @@ def report(stat: DiffStat, tag: str): @torch.inference_mode() def compare_e2e(merged, peft_model, tokenizer): batch_cpu = make_inputs(tokenizer) - batch_m = {k: v.to(f"cuda:{GPU_MERGED}") for k, v in batch_cpu.items()} - batch_p = {k: v.to(f"cuda:{GPU_PEFT}") for k, v in batch_cpu.items()} + batch_m = {k: v.to(f'cuda:{GPU_MERGED}') for k, v in batch_cpu.items()} + batch_p = {k: v.to(f'cuda:{GPU_PEFT}') for k, v in batch_cpu.items()} out_m = merged(**batch_m, output_hidden_states=True) out_p = peft_model(**batch_p, output_hidden_states=True) @@ -104,7 +104,7 @@ def compare_e2e(merged, peft_model, tokenizer): def find_linear_modules(model: nn.Module, suffixes=('q_proj', 'k_proj', 'v_proj', 'o_proj')) -> Dict[str, nn.Linear]: out = {} for name, mod in model.named_modules(): - if isinstance(mod, nn.Linear) and any(name.endswith(f".self_attn.{suf}") for suf in suffixes): + if isinstance(mod, nn.Linear) and any(name.endswith(f'.self_attn.{suf}') for suf in suffixes): out[name] = mod return out @@ -150,9 +150,9 @@ def _resolve_in_peft(peft_model: nn.Module, merged_name: str) -> Optional[nn.Mod """ candidates = [ merged_name, - f"base_model.{merged_name}", - f"base_model.model.{merged_name}", - f"base_model.model.model.{merged_name}", + f'base_model.{merged_name}', + f'base_model.model.{merged_name}', + f'base_model.model.model.{merged_name}', ] peft_named = dict(peft_model.named_modules()) for cand in candidates: @@ -169,14 +169,14 @@ def compare_weights(merged, peft_model): for name, m_lin in merged_lin.items(): p_lin = _resolve_in_peft(peft_model, name) if p_lin is None: - print(f"[SKIP] Cannot resolve in PEFT: {name}") + print(f'[SKIP] Cannot resolve in PEFT: {name}') ok_all = False continue W_merged = m_lin.weight.detach().float().cpu() W_peft_eff = peft_effective_weight(p_lin) - ok = report(tensor_diff(W_merged, W_peft_eff), f"Weights::{name}") + ok = report(tensor_diff(W_merged, W_peft_eff), 
f'Weights::{name}') ok_all = ok_all and ok return ok_all @@ -204,7 +204,7 @@ def load_models(): @torch.inference_mode() def run_generate(model, tok, prompts, device, **gen_kwargs): enc = tok(prompts, return_tensors='pt', padding=True) - enc = {k: v.to(f"cuda:{device}") for k, v in enc.items()} + enc = {k: v.to(f'cuda:{device}') for k, v in enc.items()} out = model.generate(**enc, **gen_kwargs, return_dict_in_generate=True, output_scores=True) texts = tok.batch_decode(out.sequences, skip_special_tokens=True) return out, texts @@ -216,7 +216,7 @@ def compare_texts(a_list, b_list): same = (a == b) ok &= same tag = 'SAME ' if same else 'DIFF*' - print(f"[{tag}] sample#{i}\n--- merged ---\n{a}\n--- base+peft ---\n{b}\n") + print(f'[{tag}] sample#{i}\n--- merged ---\n{a}\n--- base+peft ---\n{b}\n') return ok From 51d5e615d22114d8b534115cfaf4ef2a18e95909 Mon Sep 17 00:00:00 2001 From: yang-kichang Date: Wed, 22 Oct 2025 19:56:02 +0900 Subject: [PATCH 9/9] apply precommit run --- swift/megatron/model/gpt/config.py | 2 +- swift/megatron/model/gpt/mcore2hf_lora.py | 6 ++---- swift/megatron/model/register.py | 3 ++- swift/megatron/utils/convert.py | 3 ++- tests/megatron/test_lora_export.py | 7 +++---- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py index 431b285340..e779a827b6 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/model/gpt/config.py @@ -68,7 +68,7 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]: res['mrope_interleaved'] = mrope_interleaved if first_k_dense_replace is not None: - res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res['num_layers'] - first_k_dense_replace}' + res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}' if res.get('moe_router_score_function', 'softmax') == 'sigmoid': res['moe_router_enable_expert_bias'] = True if n_shared_experts is not None and 'moe_shared_expert_intermediate_size' not in res: diff --git a/swift/megatron/model/gpt/mcore2hf_lora.py b/swift/megatron/model/gpt/mcore2hf_lora.py index 59fd521867..b1204af6c8 100644 --- a/swift/megatron/model/gpt/mcore2hf_lora.py +++ b/swift/megatron/model/gpt/mcore2hf_lora.py @@ -42,9 +42,8 @@ def convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir: str, kv_dim = k_out // num_query_groups assert v_out // num_query_groups == kv_dim, 'k/v group out dim mismatch' - logger.info( - f'Shape extraction: num_query_groups={num_query_groups}, q_dim={q_dim}, kv_dim={kv_dim}, in_features={in_features}' - ) + logger.info(f'Shape extraction: num_query_groups={num_query_groups}, q_dim={q_dim}, ' + f'kv_dim={kv_dim}, in_features={in_features}') # Bucketize modules from peft_model state_dict logger.info('Extracting LoRA weights from loaded PEFTModel...') @@ -199,7 +198,6 @@ def convert_mlp_linear_fc1(prefix, tensors): # Split gate_up_dim into gate_dim and up_dim (usually 1:1 ratio) gate_up_dim = B.shape[0] gate_dim = gate_up_dim // 2 - up_dim = gate_up_dim - gate_dim # Split B into gate and up B_gate = B[:gate_dim, :] diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index cb6c73f3ec..8eb865da25 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -33,7 +33,8 @@ def convert_mcore_lora_to_hf_peft(self, peft_model, mg_model, hf_model, dst_dir: # only for gpt model type if self.megatron_model_type != 'gpt': raise ValueError( - f'convert_mcore_lora_to_hf_peft is only supported for gpt 
model type, but got {self.megatron_model_type}' + f'convert_mcore_lora_to_hf_peft is only supported for gpt model type, ' + f'but got {self.megatron_model_type}' ) from .gpt.mcore2hf import convert_mcore_lora_to_hf_peft diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index dae7a25367..c7c9288e73 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -312,7 +312,8 @@ def convert_mcore2hf(args: ExportArguments) -> None: load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False) if args.to_hf and not args.merge_lora: logger.info(f'Saving LoRA adapter to `{args.output_dir}` ...') - assert megatron_args.multi_latent_attention is False, 'Multi-latent attention is not supported for LoRA conversion.' + assert megatron_args.multi_latent_attention is False, ( + 'Multi-latent attention is not supported for LoRA conversion.') peft_model.config = asdict(peft_model.config) # for PEFT <= 0.17.1 # Convert Megatron Core LoRA to HuggingFace PEFT format diff --git a/tests/megatron/test_lora_export.py b/tests/megatron/test_lora_export.py index 0d1781544d..327cb3b6e7 100644 --- a/tests/megatron/test_lora_export.py +++ b/tests/megatron/test_lora_export.py @@ -74,9 +74,8 @@ def tensor_diff(a: torch.Tensor, b: torch.Tensor) -> DiffStat: def report(stat: DiffStat, tag: str): ok = stat.max_abs <= ATOL + RTOL * max(1.0, stat.mean_abs) - print( - f"[{tag}] max|Δ|={stat.max_abs:.3e} mean|Δ|={stat.mean_abs:.3e} cos={stat.cos_sim:.6f} -> {'OK' if ok else 'MISMATCH'}" - ) + print(f'[{tag}] max|Δ|={stat.max_abs:.3e} mean|Δ|={stat.mean_abs:.3e} ' + f"cos={stat.cos_sim:.6f} -> {'OK' if ok else 'MISMATCH'}") return ok @@ -248,7 +247,7 @@ def main(): # ===== 2) Module-wise effective weight comparison ===== print('\n=== (2) Module-wise effective weights ===') - ok_w = compare_weights(merged, peft_model) + compare_weights(merged, peft_model) # Summary print('\n=== SUMMARY ===')
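
A minimal usage sketch of the exported, unmerged adapter, assuming placeholder paths (the same ones used in tests/megatron/test_lora_export.py): once the export path above runs with `to_hf=True` and `merge_lora=False`, `adapter_model.safetensors` and `adapter_config.json` land in `output_dir`, and that directory can be attached to the HF base model with PEFT. This mirrors what the test does and uses only APIs already exercised there; it is not part of the patch series itself.

    # Hedged sketch; '/path/to/...' are placeholders, as in the test configuration.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    # Load the HF base model the LoRA was trained against.
    base = AutoModelForCausalLM.from_pretrained('/path/to/base/model', torch_dtype=torch.bfloat16).eval()

    # Attach the exported adapter directory (contains adapter_model.safetensors + adapter_config.json).
    model = PeftModel.from_pretrained(base, '/path/to/adapter/directory').eval()

    # Optionally fold the LoRA weights into the base weights for merged inference,
    # which is what the greedy-decoding comparison in the test validates against.
    merged = model.merge_and_unload().eval()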