
lora.Linear.weight Parameters Change After Loading Checkpoint in Train Mode Leading to Inconsistent Evaluation Results #190


Description

@sunpihai-up

Issue Summary

While fine-tuning a model in which some nn.Linear layers were replaced with lora.Linear, I noticed that the evaluation results during training differ from those obtained after loading a checkpoint. More specifically, a load-infer-save cycle on a checkpoint, without any training in between, changed the weight parameters of the lora.Linear layers. Other parameters within lora.Linear, such as bias and lora_A, were unaffected.

Steps to Reproduce

  1. Replace certain nn.Linear layers within the model with lora.Linear for fine-tuning.
  2. Save the entire model state without differentiating between LoRA-specific parameters and pretrained model parameters.
  3. Ensure the model is in train mode.
  4. Load the saved checkpoint using load_state_dict.
  5. Observe that the weight parameter of lora.Linear layers changes after loading, which leads to inconsistent evaluation outcomes.

Root Cause Analysis

The problem appears to occur because calling load_state_dict while the model is in train mode leaves the weight parameters of the lora.Linear layers in an altered state. This seems related to how lora.Linear merges and unmerges the LoRA update with the pretrained weight when switching between eval and train modes: if the checkpoint was saved with the update already merged into weight, loading it into layers that believe they are unmerged means the update can be merged a second time on the next eval() call.
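A minimal sketch of this suspected mechanism, assuming the lora.Linear here comes from the loralib package (microsoft/LoRA) with its default merge_weights=True; the layer sizes and the random perturbation of lora_B are only for illustration:

import torch
import torch.nn as nn
import loralib as lora

torch.manual_seed(0)

# A single LoRA layer with weight merging enabled (the loralib default).
layer = lora.Linear(8, 8, r=4, merge_weights=True)
# lora_B is zero-initialised, so perturb it to make the LoRA update non-zero.
nn.init.normal_(layer.lora_B)

# Simulate a checkpoint saved while the layer is in eval mode
# (the LoRA update merged into .weight).
layer.eval()
ckpt = {k: v.clone() for k, v in layer.state_dict().items()}

# Load that checkpoint while the layer is in train mode (update unmerged).
layer.train()
layer.load_state_dict(ckpt, strict=True)

# Switching back to eval merges the update again, so .weight drifts away
# from the saved value.
layer.eval()
# Expected: False, if the double-merge hypothesis above is correct.
print(torch.allclose(layer.state_dict()["weight"], ckpt["weight"]))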

Solution Applied

To address this issue, switch the model to eval mode before calling load_state_dict. With that order, the weight parameters of the lora.Linear layers are identical before and after loading, and switching between eval and train modes afterward no longer produces anomalies.
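A sketch of the workaround, using the same hypothetical Model(...) and checkpoint as the reproduction script below:

# Workaround sketch: load the checkpoint with the model in eval mode.
ckp = torch.load(checkpoint_path, map_location="cpu")
model = Model(...)              # model containing lora.Linear layers
model.eval()                    # load with the LoRA weights in the merged state
model.load_state_dict(ckp, strict=True)
model.train()                   # switching back to train mode afterwards is safe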

Is this behavior expected? If so, it would be helpful to document this behavior or adjust the implementation to prevent confusion among other users.

The following script may help reproduce the issue.

import torch


def compare_model_weights(state_dict1, state_dict2):
    # Compare two state_dict objects: check that they have the same keys
    # and that the tensors stored under each common key are equal.
    keys1 = set(state_dict1.keys())
    keys2 = set(state_dict2.keys())

    missing_in_model1 = keys2 - keys1  # Keys present in model2 but not in model1
    missing_in_model2 = keys1 - keys2  # Keys present in model1 but not in model2

    all_match = True

    if missing_in_model1 or missing_in_model2:
        all_match = False
        print("State dict keys do not match.\n")

        if missing_in_model1:
            print(f"Keys missing in model1: {missing_in_model1}\n")

        if missing_in_model2:
            print(f"Keys missing in model2: {missing_in_model2}\n")

    common_keys = keys1.intersection(keys2)
    for key in common_keys:
        if not torch.allclose(state_dict1[key], state_dict2[key]):
            all_match = False
            print(f"Weight mismatch found at layer: {key}\n")
            print(f"Model 1 tensor: {state_dict1[key]}\n")
            print(f"Model 2 tensor: {state_dict2[key]}\n")
            print("-" * 80 + "\n")

    if all_match:
        print("All weights match.")
    return all_match


checkpoint_path = "..."
# The checkpoint contains all of the model's weights,
# both the LoRA parameters and the pretrained parameters.
ckp = torch.load(checkpoint_path, map_location="cpu")

# The model contains lora.Linear layers.
model = Model(...)
# Loading the weights while the model is in train mode triggers the anomaly.
model.train()
model.load_state_dict(ckp, strict=True)
ckp2 = model.state_dict()

# This is the strange part: with model.eval() here, ckp and ckp2 differ;
# if it is removed, they are identical.
model.eval()
compare_model_weights(ckp, ckp2)
