copy_missing_tensors_from_source leaks excluded layers when num_hidden_layers is reduced #1615

@yiliu30

Description

Bug Description

copy_missing_tensors_from_source incorrectly copies layer tensors that were intentionally excluded by reducing num_hidden_layers. The function was designed to recover auxiliary parameters (e.g., MTP layers) that transformers silently omits during save_pretrained, but it cannot distinguish between "accidentally missing" tensors and "intentionally removed" layers.

Root Cause

The detection logic in _is_truly_missing() uses block-prefix matching: if a source tensor's block prefix (e.g., model.layers.2) does not appear in the saved output's block prefixes, the tensor is considered "missing". When num_hidden_layers is reduced from 28 to 2, layers 2–27 have block prefixes absent from the saved output, so all their tensors are flagged as missing and copied.
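The flaw can be illustrated with a minimal sketch of that block-prefix check (a hypothetical reconstruction for illustration, not the actual auto-round source; the helper names are made up):

```python
import re

# A block prefix is the layer-level path, e.g. "model.layers.2"
_BLOCK_RE = re.compile(r"^(.*?\.layers\.\d+)\.")

def block_prefix(name: str):
    m = _BLOCK_RE.match(name)
    return m.group(1) if m else None

def is_truly_missing(source_name, saved_names):
    saved_prefixes = {p for p in map(block_prefix, saved_names) if p}
    prefix = block_prefix(source_name)
    # Bug: a layer removed by slicing also has no prefix in the saved
    # output, so it looks identical to a genuinely omitted tensor
    # (e.g., an MTP parameter) and gets copied.
    return prefix is not None and prefix not in saved_prefixes

# Saved output only contains layers 0 and 1 after slicing:
saved = [
    "model.layers.0.mlp.down_proj.qweight",
    "model.layers.1.mlp.down_proj.qweight",
]
print(is_truly_missing("model.layers.5.mlp.down_proj.weight", saved))  # True -> leaked
```

Any tensor from a sliced-away layer (index 2 and above here) is flagged as missing and copied, which is exactly the leak described below.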

Impact

When a user slices a model (e.g., for testing or distillation), the quantized output ends up containing the full original weights for all excluded layers in model_extra_tensors.safetensors. The WOQ path even quantizes them with RTN before writing.

For Qwen/Qwen3-0.6B sliced from 28 to 2 layers: 650 tensors (202.7 MB) from layers 2–27 are incorrectly copied — defeating the purpose of the layer reduction and inflating the output size.

Reproducer

"""
Usage:
    pip install auto-round
    python reproduce.py
"""
import os
import shutil
import tempfile

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound


def main():
    model_name = "Qwen/Qwen3-0.6B"
    num_layers_to_keep = 2
    output_dir = tempfile.mkdtemp(prefix="repro_quantized_")

    print(f"1. Loading {model_name} and slicing to {num_layers_to_keep} layers...")
    model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    original_num_layers = model.config.num_hidden_layers
    model.model.layers = model.model.layers[:num_layers_to_keep]
    model.config.num_hidden_layers = num_layers_to_keep
    print(f"   Original layers: {original_num_layers}, kept: {num_layers_to_keep}")

    print("\n2. Running AutoRound quantize_and_save (iters=0, RTN)...")
    autoround = AutoRound(model, tokenizer, bits=4, group_size=128, sym=True, iters=0, nsamples=1)
    autoround.quantize_and_save(output_dir=output_dir, format="auto_round")

    extra_shard = os.path.join(output_dir, "model_extra_tensors.safetensors")
    if os.path.exists(extra_shard):
        from safetensors import safe_open

        # Single pass: collect leaked keys and their total size.
        total_bytes = 0
        with safe_open(extra_shard, framework="pt", device="cpu") as f:
            leaked_keys = sorted(f.keys())
            for key in leaked_keys:
                t = f.get_tensor(key)
                total_bytes += t.nelement() * t.element_size()
        print(f"\nBUG: {len(leaked_keys)} tensor(s) incorrectly copied ({total_bytes / 1024 / 1024:.1f} MB)")
        for name in leaked_keys[:5]:
            print(f"  - {name}")
        if len(leaked_keys) > 5:
            print(f"  ... and {len(leaked_keys) - 5} more")
    else:
        print("\nOK: No extra tensors copied.")

    shutil.rmtree(output_dir)

if __name__ == "__main__":
    main()

Expected Output

OK: No extra tensors copied.

Actual Output

Found 286 tensor(s) in the source checkpoint that are absent from the saved output (e.g., MTP parameters):
model.layers.[2-27].input_layernorm, model.layers.[2-27].mlp.down_proj, ... Copying them now...
Applying WOQ[RTN] to 182 missing Linear weight(s)...
Successfully wrote 650 missing tensor(s) to 'model_extra_tensors.safetensors'

BUG: 650 tensor(s) incorrectly copied (202.7 MB)
  - model.layers.10.input_layernorm.weight
  - model.layers.10.mlp.down_proj.qweight
  - model.layers.10.mlp.down_proj.qzeros
  - model.layers.10.mlp.down_proj.scales
  - model.layers.10.mlp.gate_proj.qweight
  ... and 645 more

Possible Fix

For each block-prefix group (e.g., model.layers), extract the max layer index present in the saved output and skip any source tensor whose layer index exceeds that max. This avoids coupling to architecture-specific config keys like num_hidden_layers.
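A minimal sketch of that idea (hypothetical helper names, not a patch against the actual auto-round code):

```python
import re

# Group is the repeated-block path (e.g., "model.layers"), index is the layer number.
_LAYER_RE = re.compile(r"^(.*\.layers)\.(\d+)\.")

def max_layer_index(saved_names):
    # Map each block group to the highest layer index present in the saved output.
    maxima = {}
    for name in saved_names:
        m = _LAYER_RE.match(name)
        if m:
            group, idx = m.group(1), int(m.group(2))
            maxima[group] = max(maxima.get(group, -1), idx)
    return maxima

def is_intentionally_removed(source_name, saved_maxima):
    m = _LAYER_RE.match(source_name)
    if not m:
        return False  # not a layer tensor; the missing-tensor copy still applies
    group, idx = m.group(1), int(m.group(2))
    # A source layer index above the saved maximum means the layer was
    # sliced away on purpose: skip it instead of copying.
    return group in saved_maxima and idx > saved_maxima[group]

saved = [
    "model.layers.0.self_attn.q_proj.qweight",
    "model.layers.1.mlp.up_proj.qweight",
]
maxima = max_layer_index(saved)
print(is_intentionally_removed("model.layers.5.mlp.down_proj.weight", maxima))   # True
print(is_intentionally_removed("model.layers.1.input_layernorm.weight", maxima))  # False
```

Because the maximum is derived from the saved output itself, genuinely omitted auxiliary tensors (whose layer indices fall within the kept range, or which live outside any `layers.N` group) would still be recovered.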
