Merged

Commits (30)
- 0219529 [WIP] ENH: Adapter injection based on state_dict (BenjaminBossan, Jul 9, 2025)
- c8b96b8 Fix small bug (BenjaminBossan, Jul 9, 2025)
- 7bcea4c Refactor low level API tests to use pytest (BenjaminBossan, Jul 10, 2025)
- 73f53b1 Add fix for VeRA + direct injection (BenjaminBossan, Jul 10, 2025)
- ae2258a Add default target_modules for LoHa and LoKr (BenjaminBossan, Jul 10, 2025)
- bac32c6 Add **kwargs to PEFT model classes __init__ (BenjaminBossan, Jul 10, 2025)
- 28bf441 Add tests, fix a few small issues (BenjaminBossan, Jul 10, 2025)
- 699757d Remove __init__ methods from PEFT models (BenjaminBossan, Jul 10, 2025)
- 34a5bc0 Remove VeRA changes (BenjaminBossan, Jul 10, 2025)
- a04ae0d Add documentation (BenjaminBossan, Jul 10, 2025)
- 32a1fb2 FEAT Add SHiRA Adapters (#2584) (kkb-code, Jul 14, 2025)
- 571a055 FIX: Prompt learning methods modules_to_save issue (#2646) (BenjaminBossan, Jul 14, 2025)
- 8d46c76 FIX Deploy method comp app: error in workflow file (#2645) (BenjaminBossan, Jul 14, 2025)
- d7e9436 FEAT Allow LoRA to target nn.Parameter (#2638) (BenjaminBossan, Jul 15, 2025)
- cc379fe Update README.md (#2659) (cx-alberto-simoes, Jul 21, 2025)
- 47c4a6d FIX Prefix tuning after transformers PR 38635 (#2662) (BenjaminBossan, Jul 22, 2025)
- e253b59 make method comparison device agnostic, so it can expand to more acce… (yao-matrix, Jul 22, 2025)
- ecfdab8 Update tokenizer parameter in sfttrainer across multiple examples (#2… (gapsong, Jul 23, 2025)
- a36a653 DOC Fix error in code example (#2666) (qgallouedec, Jul 24, 2025)
- 83ec3c5 ENH Llama-Adapters support for GPT2 (#2643) (efraimdahl, Jul 24, 2025)
- 7d73b01 Method Comparison: Improve formatting/layout of table (#2670) (githubnemo, Jul 24, 2025)
- ecf10f8 ENH: Targeting multiple parameters on the same module (#2665) (BenjaminBossan, Jul 24, 2025)
- 1264b9d Update extending vocab docs (#2669) (githubnemo, Jul 25, 2025)
- 1683838 Merge branch 'main' into enh-inject-adapter-based-on-state_dict (BenjaminBossan, Jul 28, 2025)
- 0241bc6 Fix for SHiRA (BenjaminBossan, Jul 28, 2025)
- fe1223c Merge branch 'main' into enh-inject-adapter-based-on-state_dict (BenjaminBossan, Jul 30, 2025)
- 3f0c9bd Reviewer feedback (BenjaminBossan, Aug 1, 2025)
- 54c1364 Merge branch 'main' into enh-inject-adapter-based-on-state_dict (BenjaminBossan, Aug 1, 2025)
- 53b5da1 Merge branch 'main' into enh-inject-adapter-based-on-state_dict (BenjaminBossan, Aug 1, 2025)
- eacb767 Fix for get_base_model call (BenjaminBossan, Aug 1, 2025)
24 changes: 23 additions & 1 deletion docs/source/developer_guides/low_level_api.md
@@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.

# Adapter injection

With PEFT, you can inject trainable adapters into any `torch` module which allows you to use adapter methods without relying on the modeling classes in PEFT. Currently, PEFT supports injecting [LoRA](../conceptual_guides/adapter#low-rank-adaptation-lora), [AdaLoRA](../conceptual_guides/adapter#adaptive-low-rank-adaptation-adalora), and [IA3](../conceptual_guides/ia3) into models because for these adapters, inplace modification of the model is sufficient for finetuning it.
With PEFT, you can inject trainable adapters into any `torch` module, which allows you to use adapter methods without relying on the modeling classes in PEFT. This works for all adapters except those based on prompt learning (e.g. prefix tuning or p-tuning).

Check the table below to see when you should inject adapters.

@@ -87,6 +87,28 @@ DummyModel(
)
```

### Injection based on a `state_dict`

Sometimes you may have a PEFT adapter checkpoint, but the corresponding PEFT config is not known for whatever reason. To inject the PEFT layers for this checkpoint, you would usually have to reverse-engineer the corresponding PEFT config, most notably the `target_modules` argument, based on the `state_dict` from the checkpoint. This can be cumbersome and error-prone. To avoid this, it is also possible to call [`inject_adapter_in_model`] and pass the loaded `state_dict` as an argument:

```python
from peft import LoraConfig, inject_adapter_in_model
from safetensors.torch import load_file

model = ...
state_dict = load_file(<path-to-safetensors-file>)
lora_config = LoraConfig(...)
model = inject_adapter_in_model(lora_config, model, state_dict=state_dict)
```

In this case, PEFT will use the `state_dict` as a reference for which layers to target instead of using the PEFT config. As a user, you don't have to set the exact `target_modules` of the PEFT config for this to work. However, you should still pass a PEFT config of the right type (in this example, `LoraConfig`); you can leave `target_modules` as `None`.

Be aware that this still only creates the uninitialized PEFT layers; the values from the `state_dict` are not used to populate the model weights. To populate the weights, proceed by calling [`set_peft_model_state_dict`] as described below.
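
For illustration, a minimal sketch of populating the weights could look like this, continuing from the snippet above:

```python
from peft import set_peft_model_state_dict

# the adapter layers were already injected above; now fill them with the checkpoint values
set_peft_model_state_dict(model, state_dict)
```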

⚠️ Note that if there is a mismatch between what is configured in the PEFT config and what is found in the `state_dict`, PEFT will warn you about this. You can ignore the warning if you know that the PEFT config is not correctly specified.

> [!WARNING]
> If the original PEFT adapter was using `target_parameters` instead of `target_modules`, injecting from a `state_dict` will not work correctly. In this case, it is mandatory to use the correct PEFT config for injection.
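
For example, a minimal sketch of the explicit route could look like this (the targeted parameter name below is only a hypothetical placeholder):

```python
# The checkpoint is assumed to have been trained with `target_parameters`,
# so pass a matching config instead of relying on the state_dict keys.
lora_config = LoraConfig(target_parameters=["mlp.experts.gate_up_proj"])  # hypothetical name
model = inject_adapter_in_model(lora_config, model)
```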

## Saving the model

To only save the adapter, use the [`get_peft_model_state_dict`] function:
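
A minimal sketch of what that could look like, continuing from the injected `model` above:

```python
from peft import get_peft_model_state_dict

# returns only the adapter weights, not the base model weights
peft_state_dict = get_peft_model_state_dict(model)
print(peft_state_dict)
```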
17 changes: 14 additions & 3 deletions src/peft/mapping.py
@@ -14,7 +14,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

import torch

@@ -45,7 +45,11 @@ def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:


def inject_adapter_in_model(
peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default", low_cpu_mem_usage: bool = False
peft_config: PeftConfig,
model: torch.nn.Module,
adapter_name: str = "default",
low_cpu_mem_usage: bool = False,
state_dict: Optional[dict[str, torch.Tensor]] = None,
) -> torch.nn.Module:
r"""
A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
@@ -61,6 +65,11 @@ def inject_adapter_in_model(
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process.
state_dict (`dict`, *optional*, defaults to `None`):
If a state_dict is passed here, the adapters will be injected based on the entries of the state_dict. This
can be useful when the exact `target_modules` of the PEFT method is unknown, for instance because the
checkpoint was created without metadata. Note that the values from the state_dict are not used; only the
keys are used to determine the correct layers that should be adapted.
"""
if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.")
@@ -73,6 +82,8 @@ def inject_adapter_in_model(
tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]

# By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
peft_model = tuner_cls(
model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage, state_dict=state_dict
)

return peft_model.model
4 changes: 2 additions & 2 deletions src/peft/tuners/adalora/model.py
@@ -65,8 +65,8 @@ class AdaLoraModel(LoraModel):

# Note: don't redefine prefix here, it should be inherited from LoraModel

def __init__(self, model, config, adapter_name):
super().__init__(model, config, adapter_name)
def __init__(self, model, config, adapter_name, **kwargs):
super().__init__(model, config, adapter_name, **kwargs)

traininable_mode_counter = 0
for config in self.peft_config.values():
3 changes: 0 additions & 3 deletions src/peft/tuners/boft/model.py
@@ -74,9 +74,6 @@ class BOFTModel(BaseTuner):

prefix: str = "boft_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def _check_new_adapter_config(self, config: BOFTConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
3 changes: 0 additions & 3 deletions src/peft/tuners/c3a/model.py
@@ -55,9 +55,6 @@ class C3AModel(BaseTuner):

prefix: str = "c3a_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def _check_new_adapter_config(self, config: C3AConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
3 changes: 0 additions & 3 deletions src/peft/tuners/fourierft/model.py
@@ -58,9 +58,6 @@ class FourierFTModel(BaseTuner):

prefix: str = "fourierft_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def _check_new_adapter_config(self, config: FourierFTConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
3 changes: 0 additions & 3 deletions src/peft/tuners/ia3/model.py
@@ -75,9 +75,6 @@ class IA3Model(BaseTuner):

prefix: str = "ia3_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False):
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

@staticmethod
def _create_new_module(ia3_config, adapter_name, target, **kwargs):
# avoid eager bnb import
4 changes: 0 additions & 4 deletions src/peft/tuners/ln_tuning/model.py
@@ -65,10 +65,6 @@ class LNTuningModel(BaseTuner):

prefix: str = "ln_tuning_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
# self.adapter_name = adapter_name
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
11 changes: 11 additions & 0 deletions src/peft/tuners/loha/model.py
@@ -18,6 +18,7 @@
from torch import nn

from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
from peft.utils import TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING
from peft.utils.other import get_pattern_key

from .layer import Conv2d, Linear, LoHaLayer
@@ -110,3 +111,13 @@ def _create_and_replace(
else:
new_module = self._create_new_module(config, adapter_name, target, **kwargs)
self._replace_module(parent, target_name, new_module, target)

@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = set(
TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config
11 changes: 11 additions & 0 deletions src/peft/tuners/lokr/model.py
@@ -18,6 +18,7 @@
from torch import nn

from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
from peft.utils import TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING
from peft.utils.other import get_pattern_key

from .layer import Conv2d, Linear, LoKrLayer
@@ -112,3 +113,13 @@ def _create_and_replace(
else:
new_module = self._create_new_module(config, adapter_name, target, **kwargs)
self._replace_module(parent, target_name, new_module, target)

@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = set(
TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config
3 changes: 2 additions & 1 deletion src/peft/tuners/lora/__init__.py
@@ -18,7 +18,7 @@
from .config import EvaConfig, LoftQConfig, LoraConfig, LoraRuntimeConfig
from .eva import get_eva_state_dict, initialize_lora_eva_weights
from .gptq import GPTQLoraLinear
from .layer import Conv2d, Conv3d, Embedding, Linear, LoraLayer
from .layer import Conv2d, Conv3d, Embedding, Linear, LoraLayer, ParamWrapper
from .model import LoraModel


@@ -34,6 +34,7 @@
"LoraLayer",
"LoraModel",
"LoraRuntimeConfig",
"ParamWrapper",
"get_eva_state_dict",
"initialize_lora_eva_weights",
]
3 changes: 0 additions & 3 deletions src/peft/tuners/lora/model.py
@@ -139,9 +139,6 @@ class LoraModel(BaseTuner):

prefix: str = "lora_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def _check_new_adapter_config(self, config: LoraConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
3 changes: 0 additions & 3 deletions src/peft/tuners/lycoris_utils.py
@@ -203,9 +203,6 @@ class LycorisTuner(BaseTuner):
prefix: str
layers_mapping: dict[type[torch.nn.Module], type[LycorisLayer]]

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False):
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
3 changes: 0 additions & 3 deletions src/peft/tuners/oft/model.py
@@ -99,9 +99,6 @@ class OFTModel(BaseTuner):

prefix: str = "oft_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def _check_new_adapter_config(self, config: OFTConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
3 changes: 0 additions & 3 deletions src/peft/tuners/poly/model.py
@@ -33,9 +33,6 @@
class PolyModel(BaseTuner):
prefix: str = "poly_"

def __init__(self, model, config, adapter_name) -> None:
super().__init__(model, config, adapter_name)

@staticmethod
def _check_target_module_exists(poly_config, key):
return check_target_module_exists(poly_config, key)
3 changes: 0 additions & 3 deletions src/peft/tuners/randlora/model.py
@@ -101,9 +101,6 @@ class RandLoraModel(BaseTuner):

prefix: str = "randlora_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def _find_dim(self, config) -> tuple[int, int]:
"""
Finds the largest input and output dimensions across linear layers that have been wrapped with RandLora.
3 changes: 0 additions & 3 deletions src/peft/tuners/shira/model.py
@@ -64,9 +64,6 @@ class ShiraModel(BaseTuner):

prefix: str = "shira_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def _check_new_adapter_config(self, config: ShiraConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
11 changes: 7 additions & 4 deletions src/peft/tuners/trainable_tokens/model.py
@@ -31,9 +31,6 @@
class TrainableTokensModel(BaseTuner):
prefix: str = "trainable_tokens_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False):
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
@@ -49,13 +46,19 @@ def _prepare_adapter_config(self, peft_config, model_config):
return peft_config

def inject_adapter(
self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False
self,
model: nn.Module,
adapter_name: str,
autocast_adapter_dtype: bool = True,
low_cpu_mem_usage: bool = False,
**kwargs,
) -> None:
super().inject_adapter(
model=model,
adapter_name=adapter_name,
autocast_adapter_dtype=autocast_adapter_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
**kwargs,
)

model_config = self.get_model_config(self)