diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py
index eac8023801..95a68bf17d 100644
--- a/swift/megatron/model/gpt/mcore2hf.py
+++ b/swift/megatron/model/gpt/mcore2hf.py
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+
 from typing import Optional
 
 from megatron.training import get_args
@@ -125,3 +126,9 @@ def convert_mcore2hf(hf_model, mg_model):
     hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight)
     for layer_idx in range(args.num_layers):
         set_layer_state(args, mg_model, hf_model.model, layer_idx)
+
+
+def convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir: str, num_groups: int) -> None:
+    """Convert Megatron Core LoRA adapter to HuggingFace PEFT format."""
+    from .mcore2hf_lora import convert_mcore_lora_to_hf_peft as _convert_mcore_lora_to_hf_peft
+    _convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir, num_groups)
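The new mcore2hf_lora.py module below works by bucketing LoRA tensors per module prefix and remapping Megatron-style keys to PEFT-style keys. As a rough, hypothetical illustration of the mapping the parser expects (the exact Megatron-side prefix, e.g. the leading 'base_model.model.', is an assumption that depends on how the PEFT wrapper names its submodules):

    # Hypothetical key, consistent with the parsing/remapping logic in mcore2hf_lora.py
    mcore_key = 'base_model.model.decoder.layers.0.self_attention.linear_qkv.lora_A.default.weight'

    parts = mcore_key.split('.')
    local = f'{parts[-3]}.{parts[-1]}'   # 'lora_A.weight' (drops the '.default' segment)
    prefix = '.'.join(parts[:-3])        # '...decoder.layers.0.self_attention.linear_qkv'

    # After remap_key_for_peft() and the fused-QKV split, the q_proj slice is stored under
    # 'base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight' in adapter_model.safetensors.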
diff --git a/swift/megatron/model/gpt/mcore2hf_lora.py b/swift/megatron/model/gpt/mcore2hf_lora.py
new file mode 100644
index 0000000000..b1204af6c8
--- /dev/null
+++ b/swift/megatron/model/gpt/mcore2hf_lora.py
@@ -0,0 +1,301 @@
+# Copyright (c) Kakao Corp. (AI Alignment Team).
+# Contact: kevin.us@kakaocorp.com
+
+import os
+from collections import OrderedDict
+from dataclasses import asdict
+
+import json
+from safetensors.torch import save_file
+
+from swift.utils import get_logger
+
+logger = get_logger()
+
+
+def convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir: str, num_query_groups: int) -> None:
+    """
+    Convert Megatron Core LoRA adapter to HuggingFace PEFT format.
+
+    Args:
+        peft_model: Megatron Core PEFTModel
+        mg_model: loaded Megatron Core model (used for shape information)
+        hf_model: HuggingFace model (required for shape extraction)
+        dst_dir: directory path to save the HuggingFace PEFT adapter to
+        num_query_groups: number of attention query groups
+    """
+    os.makedirs(dst_dir, exist_ok=True)
+    dst_model = os.path.join(dst_dir, 'adapter_model.safetensors')
+    dst_cfg = os.path.join(dst_dir, 'adapter_config.json')
+
+    logger.info(f'Converting Megatron Core LoRA to HF PEFT format at {dst_dir}')
+
+    # Extract shape information from HuggingFace model
+    logger.info('Extracting shape information from HuggingFace model...')
+    attn0 = hf_model.model.layers[0].self_attn
+
+    q_out, in_features = attn0.q_proj.weight.shape  # [out, in]
+    k_out, _ = attn0.k_proj.weight.shape
+    v_out, _ = attn0.v_proj.weight.shape
+
+    q_dim = q_out // num_query_groups
+    kv_dim = k_out // num_query_groups
+    assert v_out // num_query_groups == kv_dim, 'k/v group out dim mismatch'
+
+    logger.info(f'Shape extraction: num_query_groups={num_query_groups}, q_dim={q_dim}, '
+                f'kv_dim={kv_dim}, in_features={in_features}')
+
+    # Bucketize modules from peft_model state_dict
+    logger.info('Extracting LoRA weights from loaded PEFTModel...')
+    bucket = {}  # prefix -> {local_name: tensor}
+    state_dict = peft_model.state_dict()
+
+    for fullkey, tensor in state_dict.items():
+        # Process only adapter-related keys
+        if 'lora_A' not in fullkey and 'lora_B' not in fullkey:
+            continue
+        parts = fullkey.split('.')
+
+        # Parse key considering the .default.weight format
+        if len(parts) >= 2 and parts[-2] == 'default':
+            # e.g., lora_A.default.weight -> lora_A.weight
+            local = f'{parts[-3]}.{parts[-1]}'  # e.g., lora_A.weight
+            prefix = '.'.join(parts[:-3])  # e.g., ...linear_qkv
+        else:
+            # Original logic: e.g., lora_A.weight
+            local = '.'.join(parts[-2:])  # e.g., lora_A.weight
+            prefix = '.'.join(parts[:-2])  # e.g., ...linear_qkv
+
+        bucket.setdefault(prefix, {})[local] = tensor.cpu()
+
+    dst_tensors = OrderedDict()
+
+    def push(dst, key, tensor):
+        """Create independent copy of tensor for saving + ensure contiguous memory"""
+        t = tensor.detach().clone().contiguous()
+        if 'weight' not in key:
+            logger.debug(f'Skipping non-weight key: {key}')
+            return
+        key = remap_key_for_peft(key)
+        if key in dst:
+            raise ValueError(f'Duplicate key: {key}')
+        dst[key] = t
+
+    def remap_key_for_peft(key: str) -> str:
+        """Convert key to HuggingFace PEFT format"""
+        # 1) decoder → model
+        key = key.replace('.decoder.layers.', '.model.layers.')
+        # 2) self_attention → self_attn
+        key = key.replace('.self_attention.', '.self_attn.')
+        # 3) check prefix
+        if key.startswith('model.layers.'):
+            key = 'base_model.model.' + key
+        return key
+
+    def convert_linear_proj(prefix, tensors):
+        """mcore: ...self_attention.linear_proj -> HF: ...self_attn.o_proj"""
+        new_prefix = prefix.replace('.self_attention.linear_proj', '.self_attn.o_proj')
+        for local, T in tensors.items():
+            push(dst_tensors, f'{new_prefix}.{local}', T)
+
+    def convert_linear_qkv(prefix, tensors):
+        """
+        Split Megatron Core fused qkv LoRA into HF q_proj, k_proj, v_proj
+
+        mcore:
+            A: [r, in_features] (shared)
+            B: [num_query_groups*(q_dim+kv_dim+kv_dim), r]
+        -> HF:
+            q_proj: A=[r,in], B=[num_query_groups*q_dim, r]
+            k_proj: A=[r,in], B=[num_query_groups*kv_dim, r]
+            v_proj: A=[r,in], B=[num_query_groups*kv_dim, r]
+        """
+        A = tensors.get('lora_A.weight', None)
+        B = tensors.get('lora_B.weight', None)
+        if A is None or B is None:
+            # If core weights are missing, pass through with the original key
+            for local, T in tensors.items():
+                push(dst_tensors, f'{prefix}.{local}', T)
+            return
+
+        r, in_A = A.shape
+        out_B, rB = B.shape
+        assert rB == r, f'LoRA rank mismatch: A={r}, B={rB}'
+        assert in_A == in_features, f'in_features mismatch: A={in_A}, base={in_features}'
+
+        expected_out = num_query_groups * (q_dim + kv_dim + kv_dim)
+        assert out_B == expected_out, f'Fused B out({out_B}) != expected({expected_out})'
+
+        # Reshape to [num_query_groups, (q_dim+kv_dim+kv_dim), r] then slice
+        Bg = B.reshape(num_query_groups, q_dim + kv_dim + kv_dim, r)
+        Bq = Bg[:, :q_dim, :].reshape(num_query_groups * q_dim, r)
+        Bk = Bg[:, q_dim:q_dim + kv_dim, :].reshape(num_query_groups * kv_dim, r)
+        Bv = Bg[:, q_dim + kv_dim:, :].reshape(num_query_groups * kv_dim, r)
+
+        misc = {k: v for k, v in tensors.items() if k not in ('lora_A.weight', 'lora_B.weight')}
+
+        # q_proj
+        q_prefix = prefix.replace('.self_attention.linear_qkv', '.self_attn.q_proj')
+        push(dst_tensors, f'{q_prefix}.lora_A.weight', A)
+        push(dst_tensors, f'{q_prefix}.lora_B.weight', Bq)
+
+        # k_proj
+        k_prefix = prefix.replace('.self_attention.linear_qkv', '.self_attn.k_proj')
+        push(dst_tensors, f'{k_prefix}.lora_A.weight', A)
+        push(dst_tensors, f'{k_prefix}.lora_B.weight', Bk)
+        for k, v in misc.items():
+            push(dst_tensors, f'{k_prefix}.{k}', v)
+
+        # v_proj
+        v_prefix = prefix.replace('.self_attention.linear_qkv', '.self_attn.v_proj')
+        push(dst_tensors, f'{v_prefix}.lora_A.weight', A)
+        push(dst_tensors, f'{v_prefix}.lora_B.weight', Bv)
+        for k, v in misc.items():
+            push(dst_tensors, f'{v_prefix}.{k}', v)
+
+    def convert_mla_attention(prefix, tensors):
+        """
+        Multi-Latent Attention (MLA) LoRA conversion
+
+        mcore -> HF:
+            linear_q_down_proj -> q_a_proj
+            linear_q_up_proj -> q_b_proj
+            linear_kv_down_proj -> kv_a_proj_with_mqa
+            linear_kv_up_proj -> kv_b_proj
+        """
+        # q_proj (down -> a, up -> b)
+        if '.linear_q_down_proj' in prefix:
+            new_prefix = prefix.replace('.linear_q_down_proj', '.q_a_proj')
+            for local, T in tensors.items():
+                push(dst_tensors, f'{new_prefix}.{local}', T)
+        elif '.linear_q_up_proj' in prefix:
+            new_prefix = prefix.replace('.linear_q_up_proj', '.q_b_proj')
+            for local, T in tensors.items():
+                push(dst_tensors, f'{new_prefix}.{local}', T)
+        elif '.linear_kv_down_proj' in prefix:
+            new_prefix = prefix.replace('.linear_kv_down_proj', '.kv_a_proj_with_mqa')
+            for local, T in tensors.items():
+                push(dst_tensors, f'{new_prefix}.{local}', T)
+        elif '.linear_kv_up_proj' in prefix:
+            new_prefix = prefix.replace('.linear_kv_up_proj', '.kv_b_proj')
+            for local, T in tensors.items():
+                push(dst_tensors, f'{new_prefix}.{local}', T)
+
+    def convert_mlp_linear_fc1(prefix, tensors):
+        """
+        Split MLP linear_fc1 LoRA into HF gate_proj, up_proj
+
+        mcore: linear_fc1 [gate_up_dim, in_features]
+        -> HF: gate_proj [gate_dim, in_features], up_proj [up_dim, in_features]
+        """
+        A = tensors.get('lora_A.weight', None)
+        B = tensors.get('lora_B.weight', None)
+        if A is None or B is None:
+            for local, T in tensors.items():
+                push(dst_tensors, f'{prefix}.{local}', T)
+            return
+
+        # Split gate_up_dim into gate_dim and up_dim (usually a 1:1 ratio)
+        gate_up_dim = B.shape[0]
+        gate_dim = gate_up_dim // 2
+
+        # Split B into gate and up
+        B_gate = B[:gate_dim, :]
+        B_up = B[gate_dim:, :]
+
+        misc = {k: v for k, v in tensors.items() if k not in ('lora_A.weight', 'lora_B.weight')}
+
+        # gate_proj
+        gate_prefix = prefix.replace('.mlp.linear_fc1', '.mlp.gate_proj')
+        push(dst_tensors, f'{gate_prefix}.lora_A.weight', A)
+        push(dst_tensors, f'{gate_prefix}.lora_B.weight', B_gate)
+        for k, v in misc.items():
+            push(dst_tensors, f'{gate_prefix}.{k}', v)
+
+        # up_proj
+        up_prefix = prefix.replace('.mlp.linear_fc1', '.mlp.up_proj')
+        push(dst_tensors, f'{up_prefix}.lora_A.weight', A)
+        push(dst_tensors, f'{up_prefix}.lora_B.weight', B_up)
+        for k, v in misc.items():
+            push(dst_tensors, f'{up_prefix}.{k}', v)
+
+    def convert_mlp_linear_fc2(prefix, tensors):
+        """Convert MLP linear_fc2 LoRA to HF down_proj"""
+        new_prefix = prefix.replace('.mlp.linear_fc2', '.mlp.down_proj')
+        for local, T in tensors.items():
+            push(dst_tensors, f'{new_prefix}.{local}', T)
+
+    def convert_moe_experts(prefix, tensors):
+        """MoE experts LoRA conversion"""
+        # experts[expert_idx].linear_fc1 -> experts[expert_idx].gate_proj, up_proj
+        if '.linear_fc1' in prefix:
+            convert_mlp_linear_fc1(prefix, tensors)
+        # experts[expert_idx].linear_fc2 -> experts[expert_idx].down_proj
+        elif '.linear_fc2' in prefix:
+            convert_mlp_linear_fc2(prefix, tensors)
+
+    # Execute conversion by module
+    for prefix, tensors in bucket.items():
+        # Attention conversion
+        if '.self_attention.linear_proj' in prefix:
+            convert_linear_proj(prefix, tensors)
+        elif '.self_attention.linear_qkv' in prefix:
+            convert_linear_qkv(prefix, tensors)
+        # Multi-Latent Attention conversion
+        elif any(x in prefix
+                 for x in ['.linear_q_down_proj', '.linear_q_up_proj', '.linear_kv_down_proj', '.linear_kv_up_proj']):
+            convert_mla_attention(prefix, tensors)
+        # MLP conversion
+        elif '.mlp.linear_fc1' in prefix:
+            convert_mlp_linear_fc1(prefix, tensors)
+        elif '.mlp.linear_fc2' in prefix:
+            convert_mlp_linear_fc2(prefix, tensors)
+        # MoE experts conversion (excluding router)
+        elif '.experts' in prefix and ('.linear_fc1' in prefix or '.linear_fc2' in prefix):
+            convert_moe_experts(prefix, tensors)
+        else:
+            # Copy unknown modules as-is
+            logger.warning(f'Unknown module pattern: {prefix}')
+            for local, T in tensors.items():
+                push(dst_tensors, f'{prefix}.{local}', T)
+
+    # Save converted tensors
+    save_file(dst_tensors, dst_model, metadata={'format': 'pt'})
+    logger.info(f'Saved converted LoRA tensors to {dst_model}')
+
+    # Update adapter_config.json
+    logger.info('Converting adapter config...')
+    cfg = peft_model.peft_config['default'] if isinstance(peft_model.peft_config['default'], dict) else asdict(
+        peft_model.peft_config['default'])
+
+    tm = cfg.get('target_modules', None)
+    if tm is not None:
+        if isinstance(tm, str):
+            tm = [tm]
+        new_tm = []
+        for t in tm:
+            if t == 'linear_proj':
+                new_tm.append('o_proj')
+            elif t in ('linear_qkv', 'query_key_value'):
+                new_tm.extend(['q_proj', 'k_proj', 'v_proj'])
+            elif t == 'linear_fc1':
+                new_tm.extend(['gate_proj', 'up_proj'])
+            elif t == 'linear_fc2':
+                new_tm.append('down_proj')
+            elif t == 'linear_q_down_proj':
+                new_tm.append('q_a_proj')
+            elif t == 'linear_q_up_proj':
+                new_tm.append('q_b_proj')
+            elif t == 'linear_kv_down_proj':
+                new_tm.append('kv_a_proj_with_mqa')
+            elif t == 'linear_kv_up_proj':
+                new_tm.append('kv_b_proj')
+            else:
+                new_tm.append(t)
+        cfg['target_modules'] = sorted(set(new_tm))
+
+    with open(dst_cfg, 'w', encoding='utf-8') as f:
+        json.dump(cfg, f, ensure_ascii=False, indent=2, default=str)
+
+    logger.info(f'cfg: {cfg}')
+    logger.info(f'Saved converted adapter config to {dst_cfg}')
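To make the fused-QKV handling in convert_linear_qkv concrete, here is a minimal, self-contained sketch of the same reshape/slice pattern with made-up dimensions (num_query_groups, q_dim, kv_dim, r and in_features are all placeholders):

    import torch

    num_query_groups, q_dim, kv_dim, r, in_features = 4, 64, 32, 8, 512

    A = torch.randn(r, in_features)                                        # shared lora_A
    B = torch.randn(num_query_groups * (q_dim + kv_dim + kv_dim), r)       # fused lora_B

    Bg = B.reshape(num_query_groups, q_dim + kv_dim + kv_dim, r)
    Bq = Bg[:, :q_dim, :].reshape(num_query_groups * q_dim, r)
    Bk = Bg[:, q_dim:q_dim + kv_dim, :].reshape(num_query_groups * kv_dim, r)
    Bv = Bg[:, q_dim + kv_dim:, :].reshape(num_query_groups * kv_dim, r)

    assert Bq.shape == (num_query_groups * q_dim, r)                       # q_proj.lora_B
    assert Bk.shape == Bv.shape == (num_query_groups * kv_dim, r)          # k_proj / v_proj lora_B
    # lora_A is reused unchanged for q_proj, k_proj and v_proj.

The per-group slicing relies on Megatron's fused-QKV layout, which stores the query rows followed by the key and value rows for each query group.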
diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py
index 93f892c2e8..8eb865da25 100644
--- a/swift/megatron/model/register.py
+++ b/swift/megatron/model/register.py
@@ -28,6 +28,18 @@ class MegatronModelMeta:
 
     extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None
 
+    def convert_mcore_lora_to_hf_peft(self, peft_model, mg_model, hf_model, dst_dir: str, num_groups: int) -> None:
+        """Convert Megatron Core LoRA adapter to HuggingFace PEFT format."""
+        # only for gpt model type
+        if self.megatron_model_type != 'gpt':
+            raise ValueError(
+                f'convert_mcore_lora_to_hf_peft is only supported for gpt model type, '
+                f'but got {self.megatron_model_type}'
+            )
+
+        from .gpt.mcore2hf import convert_mcore_lora_to_hf_peft
+        convert_mcore_lora_to_hf_peft(peft_model, mg_model, hf_model, dst_dir, num_groups)
+
 
 def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False):
     megatron_model_type = megatron_model_meta.megatron_model_type
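For reference, the target_modules rewrite at the end of mcore2hf_lora.py maps Megatron module names to their HF equivalents before writing adapter_config.json. A condensed, illustrative sketch of that translation (the input list is hypothetical, and the MLA names are omitted here):

    mapping = {
        'linear_proj': ['o_proj'],
        'linear_qkv': ['q_proj', 'k_proj', 'v_proj'],
        'linear_fc1': ['gate_proj', 'up_proj'],
        'linear_fc2': ['down_proj'],
    }
    mcore_target_modules = ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']
    hf_target_modules = sorted({hf for t in mcore_target_modules for hf in mapping.get(t, [t])})
    # -> ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj']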
diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py
index aa3202f580..c7c9288e73 100644
--- a/swift/megatron/utils/convert.py
+++ b/swift/megatron/utils/convert.py
@@ -2,7 +2,7 @@
 
 import math
 from contextlib import contextmanager
-from dataclasses import fields
+from dataclasses import asdict, fields
 from typing import Any, Dict
 
 import torch
@@ -303,15 +303,35 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     if megatron_args.load is None:
         raise ValueError('Please specify `--mcore_model`.')
     load_checkpoint([mg_model], None, None, strict=True)
+    if args.to_hf:
+        hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision)
+
     if megatron_args.adapter_load is not None:
         peft_model = prepare_mcore_model(mg_model)
         with adapter_state_dict_context():
             load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False)
-        logger.info('Merge LoRA...')
-        mg_model = peft_model.merge_and_unload()
+        if args.to_hf and not args.merge_lora:
+            logger.info(f'Saving LoRA adapter to `{args.output_dir}` ...')
+            assert megatron_args.multi_latent_attention is False, (
+                'Multi-latent attention is not supported for LoRA conversion.')
+
+            peft_model.config = asdict(peft_model.config)  # for PEFT <= 0.17.1
+            # Convert Megatron Core LoRA to HuggingFace PEFT format
+            megatron_model_meta.convert_mcore_lora_to_hf_peft(
+                peft_model=peft_model,
+                mg_model=mg_model,
+                hf_model=hf_model,
+                dst_dir=args.output_dir,
+                num_groups=megatron_args.num_query_groups
+                if megatron_args.group_query_attention else megatron_args.num_attention_heads)
+            logger.info('LoRA adapter saved successfully.')
+            return
+        else:
+            logger.info('Merge LoRA...')
+            mg_model = peft_model.merge_and_unload()
     logger.info('Megatron model created successfully.')
-    if args.to_hf:
-        hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision)
+
+    if args.to_hf and args.merge_lora:
         megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
         if args.test_convert_precision:
             test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
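The test below checks the export two ways: generated text from a merged checkpoint versus base + exported adapter, and module-wise effective weights. The weight check leans on the standard LoRA identity W_eff = W0 + (alpha / r) * B @ A, which is also what merge_and_unload() bakes into the base weights. A tiny numeric sketch with made-up shapes:

    import torch

    out_f, in_f, r, alpha = 16, 12, 4, 8      # hypothetical layer shape, rank and alpha
    W0 = torch.randn(out_f, in_f)             # frozen base weight
    A = torch.randn(r, in_f)                  # lora_A.weight
    B = torch.randn(out_f, r)                 # lora_B.weight

    W_eff = W0 + (alpha / r) * (B @ A)        # what peft_effective_weight() reconstructs
    assert W_eff.shape == W0.shape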
diff --git a/tests/megatron/test_lora_export.py b/tests/megatron/test_lora_export.py
new file mode 100644
index 0000000000..327cb3b6e7
--- /dev/null
+++ b/tests/megatron/test_lora_export.py
@@ -0,0 +1,260 @@
+import math
+import os
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from peft import PeftModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# ------------------------------
+# Configuration
+# ------------------------------
+MERGED_MODEL = '/path/to/merged/model'
+BASE_MODEL = '/path/to/base/model'
+ADAPTER_DIR = '/path/to/adapter/directory'
+TOKENIZER = BASE_MODEL
+
+# GPU assignment (0-based)
+GPU_MERGED = 0
+GPU_PEFT = 1
+
+# Tolerance
+ATOL = 1e-5
+RTOL = 1e-4
+
+DTYPE = torch.float16
+MAX_NEW_TOKENS = 128
+PAD_AS_EOS = True
+
+PROMPTS = [
+    'User: Please explain the definition of CPU.\nAssistant: ',
+    "Translate the following sentence to English: 'The wind is so refreshing today.'",
+]
+
+
+# -----------------------------
+# Utilities
+# -----------------------------
+def pin_to_single_gpu(model_name: str, gpu_index: int, dtype=torch.bfloat16):
+    device_map = {'': gpu_index}
+    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device_map).eval()
+    return model
+
+
+def attach_peft_on_gpu(base_model_name: str, adapter_dir: str, gpu_index: int, dtype=torch.bfloat16):
+    device_map = {'': gpu_index}
+    base = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, device_map=device_map).eval()
+    model = PeftModel.from_pretrained(base, adapter_dir).eval()
+    return model
+
+
+def make_inputs(tokenizer: AutoTokenizer, seq_len: int = 32, batch_size: int = 2):
+    texts = [f'Verification sample #{i}.' for i in range(batch_size)]
+    enc = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=seq_len)
+    return enc
+
+
+@dataclass
+class DiffStat:
+    max_abs: float
+    mean_abs: float
+    cos_sim: float
+
+
+def tensor_diff(a: torch.Tensor, b: torch.Tensor) -> DiffStat:
+    a = a.detach().float().cpu().reshape(-1)
+    b = b.detach().float().cpu().reshape(-1)
+    max_abs = (a - b).abs().max().item()
+    mean_abs = (a - b).abs().mean().item()
+    cos = float('nan') if a.norm() == 0 or b.norm() == 0 else torch.nn.functional.cosine_similarity(a, b, dim=0).item()
+    return DiffStat(max_abs, mean_abs, cos)
+
+
+def report(stat: DiffStat, tag: str):
+    ok = stat.max_abs <= ATOL + RTOL * max(1.0, stat.mean_abs)
+    print(f'[{tag}] max|Δ|={stat.max_abs:.3e} mean|Δ|={stat.mean_abs:.3e} '
+          f"cos={stat.cos_sim:.6f} -> {'OK' if ok else 'MISMATCH'}")
+    return ok
+
+
+# -----------------------------
+# 1) End-to-end comparison
+# -----------------------------
+@torch.inference_mode()
+def compare_e2e(merged, peft_model, tokenizer):
+    batch_cpu = make_inputs(tokenizer)
+    batch_m = {k: v.to(f'cuda:{GPU_MERGED}') for k, v in batch_cpu.items()}
+    batch_p = {k: v.to(f'cuda:{GPU_PEFT}') for k, v in batch_cpu.items()}
+
+    out_m = merged(**batch_m, output_hidden_states=True)
+    out_p = peft_model(**batch_p, output_hidden_states=True)
+
+    ok1 = report(tensor_diff(out_m.logits, out_p.logits), 'logits')
+    ok2 = report(tensor_diff(out_m.hidden_states[-1], out_p.hidden_states[-1]), 'hidden_states[-1]')
+    return ok1 and ok2
+
+
+# -----------------------------
+# 2) Module-wise effective weight comparison
+#    (W_eff = W0 + Σ(B@A)*scale)
+# -----------------------------
+def find_linear_modules(model: nn.Module, suffixes=('q_proj', 'k_proj', 'v_proj', 'o_proj')) -> Dict[str, nn.Linear]:
+    out = {}
+    for name, mod in model.named_modules():
+        if isinstance(mod, nn.Linear) and any(name.endswith(f'.self_attn.{suf}') for suf in suffixes):
+            out[name] = mod
+    return out
+
+
+def peft_has_lora(mod: nn.Module) -> bool:
+    return hasattr(mod, 'lora_A') and hasattr(mod, 'lora_B')
+
+
+def peft_effective_weight(linear_mod: nn.Module) -> torch.Tensor:
+    W0 = linear_mod.weight.detach().float().cpu()
+    if not peft_has_lora(linear_mod):
+        return W0
+
+    delta = torch.zeros_like(W0)
+    fan_in_fan_out = getattr(linear_mod, 'fan_in_fan_out', False)
+
+    for name, A in linear_mod.lora_A.items():
+        if name not in linear_mod.lora_B:
+            continue
+        B = linear_mod.lora_B[name]
+
+        A_w = A.weight.detach().float().cpu()
+        B_w = B.weight.detach().float().cpu()
+
+        BA = (A_w.t() @ B_w.t()).t() if fan_in_fan_out else (B_w @ A_w)
+
+        # scaling
+        if hasattr(linear_mod, 'scaling'):
+            scale = float(linear_mod.scaling[name])
+        else:
+            r = A_w.shape[0]
+            alpha = getattr(linear_mod, 'lora_alpha', r)
+            scale = float(alpha) / float(r)
+
+        delta += BA * scale
+
+    return W0 + delta
+
+
+def _resolve_in_peft(peft_model: nn.Module, merged_name: str) -> Optional[nn.Module]:
+    """
+    Based on the merged module name, sequentially try possible prefixes in the PEFT wrapper.
+    """
+    candidates = [
+        merged_name,
+        f'base_model.{merged_name}',
+        f'base_model.model.{merged_name}',
+        f'base_model.model.model.{merged_name}',
+    ]
+    peft_named = dict(peft_model.named_modules())
+    for cand in candidates:
+        if cand in peft_named:
+            return peft_named[cand]
+    return None
+
+
+@torch.inference_mode()
+def compare_weights(merged, peft_model):
+    ok_all = True
+    merged_lin = find_linear_modules(merged)
+
+    for name, m_lin in merged_lin.items():
+        p_lin = _resolve_in_peft(peft_model, name)
+        if p_lin is None:
+            print(f'[SKIP] Cannot resolve in PEFT: {name}')
+            ok_all = False
+            continue
+
+        W_merged = m_lin.weight.detach().float().cpu()
+        W_peft_eff = peft_effective_weight(p_lin)
+
+        ok = report(tensor_diff(W_merged, W_peft_eff), f'Weights::{name}')
+        ok_all = ok_all and ok
+
+    return ok_all
+
+
+# -----------------------------
+# Generation comparison
+# -----------------------------
+def load_models():
+    tok = AutoTokenizer.from_pretrained(TOKENIZER, use_fast=True)
+    if PAD_AS_EOS and tok.pad_token_id is None:
+        tok.pad_token = tok.eos_token
+    tok.padding_side = 'left'  # left padding improves batch generation stability for causal LMs
+
+    merged = AutoModelForCausalLM.from_pretrained(MERGED_MODEL, torch_dtype=DTYPE, device_map={'': GPU_MERGED}).eval()
+
+    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=DTYPE, device_map={'': GPU_PEFT}).eval()
+
+    peft_model = PeftModel.from_pretrained(base, ADAPTER_DIR).eval()
+    model = peft_model.merge_and_unload().eval()
+
+    return tok, merged, model
+
+
+@torch.inference_mode()
+def run_generate(model, tok, prompts, device, **gen_kwargs):
+    enc = tok(prompts, return_tensors='pt', padding=True)
+    enc = {k: v.to(f'cuda:{device}') for k, v in enc.items()}
+    out = model.generate(**enc, **gen_kwargs, return_dict_in_generate=True, output_scores=True)
+    texts = tok.batch_decode(out.sequences, skip_special_tokens=True)
+    return out, texts
+
+
+def compare_texts(a_list, b_list):
+    ok = True
+    for i, (a, b) in enumerate(zip(a_list, b_list)):
+        same = (a == b)
+        ok &= same
+        tag = 'SAME ' if same else 'DIFF*'
+        print(f'[{tag}] sample#{i}\n--- merged ---\n{a}\n--- base+peft ---\n{b}\n')
+    return ok
+
+
+def main():
+    torch.manual_seed(0)
+
+    tok, merged, peft_model = load_models()
+
+    # ===== (Optional) End-to-end logits/hidden comparison =====
+    print('\n=== (0) End-to-end tensors (sanity) ===')
+    _ = compare_e2e(merged, peft_model, tok)
+
+    # ===== 1) Deterministic verification (greedy) =====
+    greedy_args = dict(
+        do_sample=False,
+        max_new_tokens=MAX_NEW_TOKENS,
+        num_beams=1,  # ✅ Beam search disabled
+        repetition_penalty=1.0,  # ✅ Keep default value
+        temperature=None,
+        top_p=None,
+        top_k=None,
+        eos_token_id=tok.eos_token_id,
+        pad_token_id=tok.pad_token_id,
+        use_cache=True)
+    print('\n=== GREEDY (deterministic) ===')
+    out_m_g, texts_m_g = run_generate(merged, tok, PROMPTS, GPU_MERGED, **greedy_args)
+    out_p_g, texts_p_g = run_generate(peft_model, tok, PROMPTS, GPU_PEFT, **greedy_args)
+    ok_greedy = compare_texts(texts_m_g, texts_p_g)
+
+    # ===== 2) Module-wise effective weight comparison =====
+    print('\n=== (2) Module-wise effective weights ===')
+    compare_weights(merged, peft_model)
+
+    # Summary
+    print('\n=== SUMMARY ===')
+    print('GREEDY MATCH ✅' if ok_greedy else 'GREEDY MISMATCH ❌')
+    if not ok_greedy:
+        print('※ If greedy outputs differ, recheck the adapter/key mapping.')
+
+
+if __name__ == '__main__':
+    main()
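For completeness, a minimal sketch (all paths are placeholders) of consuming the exported adapter directory, i.e. the adapter_model.safetensors and adapter_config.json written by convert_mcore_lora_to_hf_peft, with stock PEFT:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_id = '/path/to/base/model'            # placeholder
    adapter_dir = '/path/to/exported/adapter'  # the dst_dir / output_dir used at export time

    tok = AutoTokenizer.from_pretrained(base_id)
    base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16, device_map='auto')
    model = PeftModel.from_pretrained(base, adapter_dir).eval()

    inputs = tok('Hello', return_tensors='pt').to(model.device)
    print(tok.decode(model.generate(**inputs, max_new_tokens=32)[0], skip_special_tokens=True))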