175 changes: 175 additions & 0 deletions src/peft/peft_model.py
@@ -1634,6 +1634,181 @@ def supports_lora_conversion(self, adapter_name: str = "default") -> bool:

return self.base_model.supports_lora_conversion()

def get_base_model_state_dict(self) -> dict[str, torch.Tensor]:
"""
Returns the state dict of the base model with the original model keys.

This method extracts the base model's parameters, removing PEFT-specific key modifications and filtering out
adapter-specific parameters (like LoRA matrices).

This is useful when you need to access or save the base model's weights with their original key names.

Returns:
`dict[str, torch.Tensor]`:
The base model's state dict with original keys (without PEFT modifications).

Example:
```python
>>> from transformers import AutoModelForCausalLM
>>> from peft import get_peft_model, LoraConfig

>>> base_model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> original_keys = set(base_model.state_dict().keys())

>>> peft_model = get_peft_model(base_model, LoraConfig(target_modules=["c_attn"]))
>>> base_state_dict = peft_model.get_base_model_state_dict()

>>> # The keys match the original model
>>> assert set(base_state_dict.keys()) == original_keys
```
"""
# for prompt learning methods the base model structure is not modified
if self._is_prompt_learning:
return dict(self.base_model.state_dict())

# Get state dict from the underlying model
state_dict = self.base_model.model.state_dict()

# Collect all adapter prefixes to identify adapter-specific parameters
adapter_prefixes: set[str] = set()
for config in self.peft_config.values():
prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(config.peft_type)
if prefix:
adapter_prefixes.add(prefix)

result: dict[str, torch.Tensor] = {}

for key, value in state_dict.items():
# skip adapter specific params such as .lora_A, .lora_B
is_adapter_param = False
for prefix in adapter_prefixes:
if f".{prefix}" in key or key.startswith(prefix):
is_adapter_param = True
break

if is_adapter_param:
continue

# skip adapter-specific copies in modules_to_save
if ".modules_to_save." in key or ".trainable_tokens_" in key:
continue

# Transform keys to original format by removing PEFT-specific infixes
new_key = key
while ".base_layer." in new_key:
new_key = new_key.replace(".base_layer.", ".", 1)
while ".original_module." in new_key:
new_key = new_key.replace(".original_module.", ".", 1)

result[new_key] = value

return result

def set_base_model_state_dict(
self,
state_dict: dict[str, torch.Tensor],
strict: bool = True,
):
"""
Sets the base model's state dict using original model keys.

This method takes a state dict with original model key names (without PEFT
modifications) and loads it into the base model, automatically handling
the key transformations required by PEFT (such as adding `.base_layer.`
infixes for tuner layers).

This is the counterpart to `get_base_model_state_dict` and is useful for
scenarios like loading base model weights after FSDP wrapping.

Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict with original model keys to load.
strict (`bool`, *optional*, defaults to `True`):
Whether to strictly enforce that the keys in `state_dict` match
the keys expected by the base model. If `True`, raises a
`RuntimeError` if there are missing or unexpected keys.

Returns:
`NamedTuple` with `missing_keys` and `unexpected_keys` fields
(using original key names), similar to `torch.nn.Module.load_state_dict`.

Raises:
RuntimeError: If `strict=True` and there are missing or unexpected keys.

Example:
```python
>>> from transformers import AutoModelForCausalLM
>>> from peft import get_peft_model, LoraConfig
>>> import torch

>>> base_model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> base_weights = base_model.state_dict()
>>> peft_model = get_peft_model(base_model, LoraConfig(target_modules=["c_attn"]))

>>> # Restore original base model weights on the peft wrapped model
>>> result = peft_model.set_base_model_state_dict(base_weights)
```
"""
from collections import namedtuple

_IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])

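# prompt learning leaves the base model structure untouched, so the keys already match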
if self._is_prompt_learning:
return self.base_model.load_state_dict(state_dict, strict=strict)

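# keys in this state dict carry PEFT-specific infixes such as `.base_layer.`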
current_state_dict = self.base_model.model.state_dict()

adapter_prefixes: set[str] = set()
for config in self.peft_config.values():
prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(config.peft_type)
if prefix:
adapter_prefixes.add(prefix)

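# map original (unmodified) key names to the PEFT-modified keys currently in the model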
original_to_peft_key: dict[str, str] = {}

for peft_key in current_state_dict.keys():
is_adapter_param = False
for prefix in adapter_prefixes:
if f".{prefix}" in peft_key or peft_key.startswith(prefix):
is_adapter_param = True
break

if is_adapter_param:
continue
Comment on lines +1771 to +1777 (Collaborator):

I think this is not a sufficient filter for methods like VeRA or VB-LoRA that employ weight sharing. This will be covered by the extended tests, I suppose.

An alternative approach would be to iterate over all named modules of the model and remove the keys that belong to BaseTunerLayer instances (the weight-shared keys are already caught by the prefix matching in place). But let's see what the tests say first; maybe I'm wrong and everything works fine :)
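For illustration, a minimal sketch of that module-based filter (not part of this PR): it assumes the tuner-wrapped model is `peft_model.base_model.model`, imports `BaseTunerLayer` from `peft.tuners.tuners_utils`, and the helper name `adapter_param_keys` is made up here.

```python
from peft.tuners.tuners_utils import BaseTunerLayer


def adapter_param_keys(model) -> set[str]:
    """Collect state_dict keys of parameters that live on tuner layers, excluding the wrapped base_layer."""
    keys: set[str] = set()
    for module_name, module in model.named_modules():
        if not isinstance(module, BaseTunerLayer):
            continue
        for param_name, _ in module.named_parameters():
            # everything stored on a tuner layer except its base_layer is adapter-specific
            if not param_name.startswith("base_layer."):
                keys.add(f"{module_name}.{param_name}")
    return keys


# usage sketch: skip these keys instead of (or in addition to) the prefix check
# skip_keys = adapter_param_keys(peft_model.base_model.model)
```

This catches per-layer adapter parameters regardless of their prefix, while weight-shared tensors that do not live on individual layers would still fall back to the existing prefix matching, as the comment notes.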


if ".modules_to_save." in peft_key or ".trainable_tokens_" in peft_key:
continue

original_key = peft_key
while ".base_layer." in original_key:
original_key = original_key.replace(".base_layer.", ".", 1)
while ".original_module." in original_key:
original_key = original_key.replace(".original_module.", ".", 1)

original_to_peft_key[original_key] = peft_key

peft_state_dict: dict[str, torch.Tensor] = {}
unexpected_keys: list[str] = []

for original_key, value in state_dict.items():
if original_key in original_to_peft_key:
peft_state_dict[original_to_peft_key[original_key]] = value
else:
unexpected_keys.append(original_key)

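# keys the base model expects (in original naming) that were not provided in state_dict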
missing_keys = [k for k in original_to_peft_key.keys() if k not in state_dict]

if strict and (missing_keys or unexpected_keys):
error_msgs: list[str] = []
if missing_keys:
error_msgs.append(f"Missing key(s) in state_dict: {missing_keys}")
if unexpected_keys:
error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected_keys}")
raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(error_msgs))

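# strict=False here because peft_state_dict intentionally omits adapter-specific keys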
self.base_model.model.load_state_dict(peft_state_dict, strict=False)
return _IncompatibleKeys(missing_keys=missing_keys, unexpected_keys=unexpected_keys)


class PeftModelForSequenceClassification(PeftModel):
"""