Merged

Commits (showing changes from 2 of 30 commits)
0219529
[WIP] ENH: Adapter injection based on state_dict
BenjaminBossan Jul 9, 2025
c8b96b8
Fix small bug
BenjaminBossan Jul 9, 2025
7bcea4c
Refactor low level API tests to use pytest
BenjaminBossan Jul 10, 2025
73f53b1
Add fix for VeRA + direct injection
BenjaminBossan Jul 10, 2025
ae2258a
Add default target_modules for LoHa and LoKr
BenjaminBossan Jul 10, 2025
bac32c6
Add **kwargs to PEFT model classes __init__
BenjaminBossan Jul 10, 2025
28bf441
Add tests, fix a few small issues
BenjaminBossan Jul 10, 2025
699757d
Remove __init__ methods from PEFT models
BenjaminBossan Jul 10, 2025
34a5bc0
Remove VeRA changes
BenjaminBossan Jul 10, 2025
a04ae0d
Add documentation
BenjaminBossan Jul 10, 2025
32a1fb2
FEAT Add SHiRA Adapters (#2584)
kkb-code Jul 14, 2025
571a055
FIX: Prompt learning methods modules_to_save issue (#2646)
BenjaminBossan Jul 14, 2025
8d46c76
FIX Deploy method comp app: error in workflow file (#2645)
BenjaminBossan Jul 14, 2025
d7e9436
FEAT Allow LoRA to target nn.Parameter (#2638)
BenjaminBossan Jul 15, 2025
cc379fe
Update README.md (#2659)
cx-alberto-simoes Jul 21, 2025
47c4a6d
FIX Prefix tuning after transformers PR 38635 (#2662)
BenjaminBossan Jul 22, 2025
e253b59
make method comparison device agnostic, so it can expand to more acce…
yao-matrix Jul 22, 2025
ecfdab8
Update tokenizer parameter in sfttrainer across multiple examples (#2…
gapsong Jul 23, 2025
a36a653
DOC Fix error in code example (#2666)
qgallouedec Jul 24, 2025
83ec3c5
ENH Llama-Adapters support for GPT2 (#2643)
efraimdahl Jul 24, 2025
7d73b01
Method Comparison: Improve formatting/layout of table (#2670)
githubnemo Jul 24, 2025
ecf10f8
ENH: Targeting multiple parameters on the same module (#2665)
BenjaminBossan Jul 24, 2025
1264b9d
Update extending vocab docs (#2669)
githubnemo Jul 25, 2025
1683838
Merge branch 'main' into enh-inject-adapter-based-on-state_dict
BenjaminBossan Jul 28, 2025
0241bc6
Fix for SHiRA
BenjaminBossan Jul 28, 2025
fe1223c
Merge branch 'main' into enh-inject-adapter-based-on-state_dict
BenjaminBossan Jul 30, 2025
3f0c9bd
Reviewer feedback
BenjaminBossan Aug 1, 2025
54c1364
Merge branch 'main' into enh-inject-adapter-based-on-state_dict
BenjaminBossan Aug 1, 2025
53b5da1
Merge branch 'main' into enh-inject-adapter-based-on-state_dict
BenjaminBossan Aug 1, 2025
eacb767
Fix for get_base_model call
BenjaminBossan Aug 1, 2025
14 changes: 11 additions & 3 deletions src/peft/mapping.py
@@ -14,7 +14,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

import torch

@@ -45,7 +45,11 @@ def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:


def inject_adapter_in_model(
peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default", low_cpu_mem_usage: bool = False
peft_config: PeftConfig,
model: torch.nn.Module,
adapter_name: str = "default",
low_cpu_mem_usage: bool = False,
state_dict: Optional[dict[str, torch.Tensor]] = None,
) -> torch.nn.Module:
r"""
A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
@@ -61,6 +65,8 @@ def inject_adapter_in_model(
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process.

FIXME
"""
if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.")
@@ -73,6 +79,8 @@
tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]

# By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
peft_model = tuner_cls(
model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage, state_dict=state_dict
)

return peft_model.model
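The new state_dict argument lets inject_adapter_in_model derive the modules to wrap from a saved adapter state_dict instead of peft_config.target_modules. A minimal usage sketch (not part of the PR; the model name, file path, and config values are illustrative assumptions, only the state_dict keyword itself comes from this diff):

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, inject_adapter_in_model

# Base model plus a LoRA config; with a state_dict passed below, the state_dict
# (rather than target_modules) determines which modules receive adapter layers.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
lora_config = LoraConfig(target_modules=["q_proj", "v_proj"])

# Adapter weights from an earlier run (hypothetical path); keys are expected to
# look like "<module path>.lora_A.weight" / "<module path>.lora_B.weight".
adapter_state_dict = torch.load("adapter_weights.pt")

model = inject_adapter_in_model(
    lora_config,
    base_model,
    adapter_name="default",
    state_dict=adapter_state_dict,  # new argument introduced by this PR
)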
5 changes: 3 additions & 2 deletions src/peft/tuners/lora/model.py
@@ -139,8 +139,9 @@ class LoraModel(BaseTuner):

prefix: str = "lora_"

def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False, **kwargs) -> None:
# FIXME add kwargs to other model classes
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

def _check_new_adapter_config(self, config: LoraConfig) -> None:
"""
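The **kwargs added to LoraModel.__init__ is a pass-through so that new BaseTuner arguments such as state_dict reach the base class without every tuner subclass changing its signature (the FIXME notes that the other model classes still need the same treatment). A standalone sketch of the pattern, using illustrative class names that are not PEFT internals:

class BaseTunerSketch:
    def __init__(self, model, config, adapter_name, low_cpu_mem_usage=False, state_dict=None):
        # In the real BaseTuner, state_dict is forwarded on to inject_adapter.
        self.model = model
        self.state_dict = state_dict

class LoraModelSketch(BaseTunerSketch):
    def __init__(self, model, config, adapter_name, low_cpu_mem_usage=False, **kwargs):
        # Unknown keyword arguments (e.g. state_dict) are simply forwarded upward.
        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

# A caller can now pass state_dict through the subclass without a signature change:
tuner = LoraModelSketch(model=None, config=None, adapter_name="default", state_dict={})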
101 changes: 88 additions & 13 deletions src/peft/tuners/tuners_utils.py
@@ -154,6 +154,8 @@ class BaseTuner(nn.Module, ABC):

The easiest is to check what is done in the `peft.tuners.lora.LoraModel` class.

FIXME

Attributes:
model (`torch.nn.Module`):
The model to which the adapter tuner layers will be attached.
@@ -176,6 +178,7 @@ def __init__(
peft_config: Union[PeftConfig, dict[str, PeftConfig]],
adapter_name: str,
low_cpu_mem_usage: bool = False,
state_dict: Optional[dict[str, torch.Tensor]] = None,
) -> None:
super().__init__()

@@ -200,7 +203,7 @@ def __init__(
self.active_adapter: str | list[str] = adapter_name
self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name)
if peft_config != PeftType.XLORA or peft_config[adapter_name] != PeftType.XLORA:
self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage, state_dict=state_dict)

# Copy the peft_config in the injected model.
self.model.peft_config = self.peft_config
@@ -427,14 +430,21 @@ def _check_target_module_compatiblity(self, peft_config: PeftConfig, model: nn.M
_check_lora_target_modules_mamba(peft_config, model, target_name)

def inject_adapter(
self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False
self,
model: nn.Module,
adapter_name: str,
autocast_adapter_dtype: bool = True,
low_cpu_mem_usage: bool = False,
state_dict: Optional[dict[str, torch.Tensor]] = None,
) -> None:
r"""
Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the
hood by `peft.mapping.get_peft_model` if a non-prompt tuning adapter class is passed.

The corresponding PEFT config is directly retrieved from the `peft_config` attribute of the BaseTuner class.

FIXME

Args:
model (`nn.Module`):
The model to be tuned.
@@ -444,11 +454,17 @@ def inject_adapter(
Whether to autocast the adapter dtype. Defaults to `True`.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process.
FIXME

"""
###################################
# PREPARATION OF MODEL AND CONFIG #
###################################

peft_config = self.peft_config[adapter_name]
excluded_modules = []
unmatched_modules = []
targeted_modules_from_peft_config: list[str] = [] # only relevant if state_dict is passed
# Note: If possible, all checks should be performed *at the start of this method*.
# This way, we can raise early if something goes wrong, without leaving the model
# in a bad (half-initialized) state.
@@ -498,11 +514,18 @@
if len(new_target_modules) < len(peft_config.target_modules):
peft_config.target_modules = new_target_modules

###############################
# MATCHING & CREATING MODULES #
###############################

existing_adapter_map = {}
for key, module in named_modules:
if isinstance(module, BaseTunerLayer):
existing_adapter_map[key] = module

# TODO: check if this the most robust way
state_dict_keys = {k.rsplit(".", 2)[0] for k in state_dict} if state_dict is not None else set()
Reviewer comment (Member): Are these state dict keys or module keys? I think it's the latter, no? For example, for foo.bar.weight, the module and submodule are foo and bar. If we only wanted state dict keys, a simple state_dict.keys() would have sufficed, so I think we should consider renaming it.
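For reference, a quick illustration of what the k.rsplit(".", 2)[0] expression in the diff produces; the key below is a hypothetical LoRA state_dict entry, and the result is indeed a module path rather than a raw state_dict key:

key = "model.decoder.layers.0.self_attn.q_proj.lora_A.weight"  # hypothetical key
module_key = key.rsplit(".", 2)[0]  # drops the trailing "lora_A" and "weight" parts
print(module_key)  # -> model.decoder.layers.0.self_attn.q_proj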


for key, module in named_modules:
if not key:
continue
@@ -517,18 +540,66 @@ def inject_adapter(
if excluded_modules and excluded_modules[-1] == key:
continue

result = self._check_target_module_exists(peft_config, key)
if isinstance(result, _ExcludedModule):
excluded_modules.append(key)
elif not result:
unmatched_modules.append(key)
if state_dict is None:
# normal mechanism: match the modules using the peft_config
result = self._check_target_module_exists(peft_config, key)
if isinstance(result, _ExcludedModule):
excluded_modules.append(key)
elif not result:
unmatched_modules.append(key)
else:
self.targeted_module_names.append(key)
parent, target, target_name = _get_submodules(model, key)
self._check_target_module_compatiblity(peft_config, model, target_name)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
self._create_and_replace(
peft_config, adapter_name, target, target_name, parent, current_key=key
)
Author comment (Member Author) on lines +567 to +580: Note to reviewers: this is the exact same code as before, just indented by one level.

else:
self.targeted_module_names.append(key)
parent, target, target_name = _get_submodules(model, key)
self._check_target_module_compatiblity(peft_config, model, target_name)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
# use the state_dict to match modules instead
if key not in state_dict_keys:
Reviewer comment (Member): In case state_dict is None, these keys will be empty. Just flagging.

Author reply (Member Author): We check above whether state_dict is None, so this is safe here.

unmatched_modules.append(key)
else:
self.targeted_module_names.append(key)
parent, target, target_name = _get_submodules(model, key)
self._check_target_module_compatiblity(peft_config, model, target_name)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
self._create_and_replace(
peft_config, adapter_name, target, target_name, parent, current_key=key
)

# still record what would have been matched via the config so that the two results can be compared
if self._check_target_module_exists(peft_config, key):
targeted_modules_from_peft_config.append(key)

####################
# CHECK FOR ERRORS #
####################

if state_dict is not None:
# in case that the state_dict was used as source of truth and it resulted in different outcomes than what
# would have been matched with the PEFT config, warn the user about that.
targeted_set_from_peft_config = set(targeted_modules_from_peft_config)
targeted_set_from_state_dict = set(self.targeted_module_names)
diff_peft_config = targeted_set_from_peft_config - targeted_set_from_state_dict
diff_state_dict = targeted_set_from_state_dict - targeted_set_from_peft_config
error_msg = ""
if diff_peft_config or diff_state_dict:
error_msg = (
"While injecting the PEFT adapters, an inconsistency was discovered between the PEFT config and "
"the provided state_dict. This is not necessarily an issue and can be ignored if this was the "
"intent. "
)
if diff_peft_config:
error_msg += f"The PEFT config contained these additional target modules: {sorted(diff_peft_config)}. "
if diff_state_dict:
error_msg += f"The state_dict contained these additional target modules: {sorted(diff_state_dict)}. "
if error_msg:
# FIXME for debugging purposes, raise here
1/0
warnings.warn(error_msg)

if not self.targeted_module_names and not uses_dummy_target_modules:
if excluded_modules and not unmatched_modules:
@@ -578,6 +649,10 @@ def inject_adapter(
"See for example https://github.com/huggingface/peft/issues/2018."
)

################
# HOUSEKEEPING #
################

# It's important to set the adapter here (again), because otherwise it can happen that if a 2nd adapter is
# added, and it targets different layer(s) than the first adapter (which is active), then those different
# layers will be activated, which we don't want.
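To summarize the new "CHECK FOR ERRORS" section above: the modules matched via the PEFT config are compared against the modules matched via the state_dict, and any difference is surfaced as a warning (the 1/0 is a temporary debugging raise, as the FIXME notes). A self-contained sketch of that comparison, with illustrative names that are not the actual PEFT internals:

import warnings

def warn_on_mismatch(matched_from_config: set, matched_from_state_dict: set) -> None:
    # Modules that only the config would have targeted vs. only the state_dict targeted.
    only_config = matched_from_config - matched_from_state_dict
    only_state_dict = matched_from_state_dict - matched_from_config
    if not (only_config or only_state_dict):
        return
    msg = (
        "While injecting the PEFT adapters, an inconsistency was discovered between the "
        "PEFT config and the provided state_dict. This is not necessarily an issue. "
    )
    if only_config:
        msg += f"The PEFT config contained these additional target modules: {sorted(only_config)}. "
    if only_state_dict:
        msg += f"The state_dict contained these additional target modules: {sorted(only_state_dict)}. "
    warnings.warn(msg)

# Example: the config targets q_proj and v_proj, but the state_dict only covers q_proj.
warn_on_mismatch({"q_proj", "v_proj"}, {"q_proj"})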