Commits (30)
- 4c6d15f: Tests and inital implementation for embed_tokens (romitjain, Oct 29, 2025)
- 4b91220: Minor fixes (romitjain, Oct 30, 2025)
- 46b803e: Fixed all tests and made updates to logic (romitjain, Oct 31, 2025)
- 37b1e06: Nit (romitjain, Oct 31, 2025)
- 8388aa8: Added contigious check for export (romitjain, Nov 4, 2025)
- cd6c6d0: Apply suggestion from @BenjaminBossan (romitjain, Nov 4, 2025)
- 0cb44e8: Addressed PR comments (romitjain, Nov 5, 2025)
- 628ce10: Update src/peft/tuners/lora/model.py (romitjain, Nov 7, 2025)
- 602ce10: Update src/peft/tuners/lora/model.py (romitjain, Nov 7, 2025)
- e2d0345: Apply suggestions from code review (romitjain, Nov 7, 2025)
- 7880032: Removed redundant change (romitjain, Nov 7, 2025)
- f73af50: Merge branch 'enh/tie-target-modules' of github.com:romitjain/peft in… (romitjain, Nov 7, 2025)
- 46cca1e: Handling target_modules as str (romitjain, Nov 7, 2025)
- 2267a48: Update src/peft/tuners/tuners_utils.py (romitjain, Nov 10, 2025)
- 5d5b8e4: Updated regex matching (romitjain, Nov 12, 2025)
- c7cfe40: Apply suggestion from @BenjaminBossan (romitjain, Nov 13, 2025)
- 8294ec7: Added find layer by tensor (romitjain, Nov 13, 2025)
- 7370a21: Merge branch 'main' of github.com:romitjain/peft into enh/tie-target-… (romitjain, Nov 13, 2025)
- 1da895f: Fixed tests (romitjain, Nov 14, 2025)
- d86ff7d: Nit (romitjain, Nov 18, 2025)
- dc03dd4: Small fix to ensure correct layer name gets saved for target modules (romitjain, Nov 19, 2025)
- c79a64c: Merge branch 'main' of github.com:huggingface/peft into enh/tie-targe… (romitjain, Nov 20, 2025)
- 0715451: Merge branch 'main' of github.com:huggingface/peft into enh/tie-targe… (romitjain, Dec 15, 2025)
- dbb0096: Apply suggestions from code review (romitjain, Dec 15, 2025)
- 06d4b7f: Merge branch 'enh/tie-target-modules' of github.com:romitjain/peft in… (romitjain, Dec 15, 2025)
- 67a71d6: Updated matching logic (romitjain, Dec 15, 2025)
- 8889558: Merge branch 'main' of github.com:romitjain/peft into enh/tie-target-… (romitjain, Jan 5, 2026)
- 9f7702f: Merge branch 'main' of github.com:huggingface/peft into enh/tie-targe… (romitjain, Jan 12, 2026)
- 4d5d681: Merge branch 'main' into enh/tie-target-modules (romitjain, Jan 16, 2026)
- ba4d81f: Merge branch 'main' into enh/tie-target-modules (romitjain, Jan 29, 2026)
8 changes: 7 additions & 1 deletion src/peft/tuners/lora/config.py
@@ -382,6 +382,11 @@ class LoraConfig(PeftConfig):
`target_parameters`. As an example, for Llama4, you can pass:
`target_parameters=['feed_forward.experts.gate_up_proj', 'feed_forward.experts.down_proj']`. Passing a
string for regex matching is not implemented yet.
ensure_weight_tying (`bool`, *optional*):
Whether to tie weights after PEFT initialization. This ensures that the adapters added to the tied
layers are also tied. This is only applicable to layers passed via `modules_to_save` and
`target_modules`.

"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -670,7 +675,7 @@ class LoraConfig(PeftConfig):
"Whether to tie weights or not after peft initialization. "
"This will ensure that the adapters added to the tied layers "
"are also tied. This is only applicable for layers passed via "
"`modules_to_save`."
"`modules_to_save` and `target_modules`."
)
},
)
@@ -695,6 +700,7 @@ def __post_init__(self):

if self.ensure_weight_tying:
self.modules_to_tie = None
self.target_modules_to_tie = None

if isinstance(self.target_parameters, str):
raise TypeError("`target_parameters` must be a list of strings or None.")
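For context on how the new flag is meant to be used, here is a minimal usage sketch; the checkpoint name and the target module list are illustrative assumptions, not taken from this PR:

    # Minimal sketch: enabling ensure_weight_tying for a checkpoint whose lm_head
    # is tied to embed_tokens (tie_word_embeddings=True). Checkpoint is illustrative.
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
    config = LoraConfig(
        r=8,
        target_modules=["q_proj", "v_proj", "embed_tokens", "lm_head"],
        ensure_weight_tying=True,  # keep the adapters on embed_tokens and lm_head tied
    )
    peft_model = get_peft_model(model, config)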
14 changes: 14 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -156,6 +156,8 @@ def update_layer(
arrow_config: ArrowConfig = None,
qalora_group_size: int = 32,
inference_mode: bool = False,
is_tied: bool = False,
tied_adapters: dict = {},
**kwargs,
):
# collect the kwargs
@@ -195,6 +197,16 @@
# Actual trainable parameters
self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias)
if is_tied:
if not tied_adapters:
raise RuntimeError("Layer is marked as tied, but tied adapters are not provided")

lora_A_params = tied_adapters["lora_A"]
lora_B_params = tied_adapters["lora_B"]

self.lora_A[adapter_name].weight = torch.nn.Parameter(lora_A_params)
self.lora_B[adapter_name].weight = torch.nn.Parameter(lora_B_params)

self.lora_bias[adapter_name] = lora_bias

if use_rslora:
@@ -631,6 +643,8 @@ def __init__(
use_alora=use_alora,
lora_bias=lora_bias,
arrow_config=arrow_config,
is_tied=kwargs.get("is_tied", False),
tied_adapters=kwargs.get("tied_adapters"),
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

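As background on what "tied" means at the tensor level, here is a standalone PyTorch illustration (not code from this PR) of two modules sharing a single nn.Parameter, which is the kind of sharing the tied_adapters path above aims to preserve for the LoRA weights:

    import torch
    import torch.nn as nn

    # Two modules holding the same nn.Parameter share storage, so a change made
    # through one module is observable through the other.
    emb = nn.Embedding(10, 4)
    lm_head = nn.Linear(4, 10, bias=False)
    lm_head.weight = emb.weight  # tie the two modules

    with torch.no_grad():
        emb.weight[0] += 1.0

    assert torch.equal(lm_head.weight[0], emb.weight[0])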
59 changes: 53 additions & 6 deletions src/peft/tuners/lora/model.py
@@ -187,6 +187,15 @@ def _create_and_replace(
r = lora_config.rank_pattern.get(r_key, lora_config.r)
alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)

is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or [])
tied_adapters = {}
if is_tied:
tied_module = self.model.get_input_embeddings()
emb_A = tied_module.lora_embedding_A[adapter_name]
emb_B = tied_module.lora_embedding_B[adapter_name]

tied_adapters = {"lora_A": emb_B.t(), "lora_B": emb_A.t()}

kwargs = {
"r": r,
"lora_alpha": alpha,
@@ -204,6 +213,8 @@
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
"parameter_name": parameter_name,
"is_tied": is_tied,
"tied_adapters": tied_adapters,
}

# for torchao merging, we need the get_apply_tensor_subclass from the quantization config
@@ -249,9 +260,17 @@
if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)

def _replace_module(self, parent, child_name, new_module, child):
self._replace_module(
parent=parent,
child_name=target_name,
new_module=new_module,
child=target,
is_tied=is_tied,
adapter_name=adapter_name,
)

def _replace_module(self, parent, child_name, new_module, child, is_tied, adapter_name):
# override in LoraModel to handle quantized weights properly

setattr(parent, child_name, new_module)
@@ -806,8 +825,36 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap

return tensors_lora

def _add_modules_to_tie(self, peft_config, tied_weight_keys):
modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
missing_keys = set(tied_weight_keys) - modules_to_save
def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
"""
`tied_weight_keys` contains the layers tied to the embedding layer. Add the embedding layer to, and remove the
rest of the tied layers from, `modules_to_save`. Maintain a separate set for the layers to be tied.

Args:
peft_config (LoraConfig): The LoRA config to update in place.
tied_weight_keys (list[str]): Names of the weight keys tied to the embedding layer.
"""
tied_weight_keys = set(tied_weight_keys)
setattr(peft_config, "modules_to_tie", tied_weight_keys)

modules_to_save = getattr(peft_config, "modules_to_save", []) or []
if "embed_tokens" not in modules_to_save:
[Member] If the embedding layer has a different name, this won't be correct, right? It's probably still fine for now.

[Contributor Author] Yes, that is true. I am not able to find a way to get the embedding layer name from the model.

[Member] We could try to find the layer name whose parameter corresponds to model.get_input_embedding() but I'm fine with assuming the name here. Let's just add a comment.

modules_to_save.append("embed_tokens")

for m in tied_weight_keys:
if m in modules_to_save:
modules_to_save.remove(m)
[Collaborator] I'm not sure how often this will generate a match. If I understand correctly, tied_weight_keys are fully-qualified keys. So this check will only match if the keys in modules_to_save are also fully-qualified. I don't think this happens often. cc @BenjaminBossan

[Contributor Author] @githubnemo In the flow, I am unable to find any place where we are converting keys in modules_to_save to fully qualified keys.

The only two relevant checks I see are:

  1. https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py#L1655
  2. https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py#L1016

Here, we want to make sure that fully qualified keys from tied_weight_keys match the ones in modules_to_save. I propose that I do the following:

For every key in model.named_parameters, I perform a check similar to what is given in (1) and match it with tied_weight_keys. If both of them give a match, I remove the key from modules_to_save.
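A minimal sketch of what that proposal could look like; the helper name is hypothetical, it iterates named_modules rather than named_parameters for brevity, and the exact form of the tied keys is an assumption rather than code from this PR:

    import torch.nn as nn

    def drop_tied_entries(model: nn.Module, tied_weight_keys: set[str], modules_to_save: list[str]) -> list[str]:
        # For every module in the model, check whether its fully-qualified name is
        # covered by an entry in modules_to_save (exact or ".<entry>" suffix match,
        # as in _check_target_module_exists) and also corresponds to a tied weight
        # key; if so, drop that entry, since the tied copy is handled separately.
        kept = set(modules_to_save)
        for name, _ in model.named_modules():
            is_tied = name in tied_weight_keys or f"{name}.weight" in tied_weight_keys
            if not is_tied:
                continue
            for entry in list(kept):
                if name == entry or name.endswith(f".{entry}"):
                    kept.discard(entry)
        return sorted(kept)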


setattr(peft_config, "modules_to_save", modules_to_save)

def _add_targets_to_tie(self, peft_config, tied_weight_keys):
tied_weight_keys = set(tied_weight_keys)
setattr(peft_config, "target_modules_to_tie", tied_weight_keys)

target_modules = set(getattr(peft_config, "target_modules", []) or [])
[Member] We need to consider the case that target_modules is a string and not a list of strings. If it's a string, we perform a regex match. Honestly, I'm not sure if there is a good solution. So far, I have 3 ideas:

  1. We could try to use the model.targeted_module_names attribute, which lists all targeted modules after the targets have been resolved. But that would mean that we need to first apply all LoRA layers and only then can we check for tied layers, which is the opposite order of how things are implemented right now.
  2. We could try using the string directly and then for example do something like: config.target_modules += f"|{missing_key}" but this is very brittle and won't work with all regexes, so I would like to avoid this.
  3. We could forbid using ensure_weight_tying=True and target_modules = <str>. Then we'd raise an error and tell users they have to pass a list of str if they want ensure_weight_tying.

[Contributor Author] Yes, after going through the code a few more times, I realized this would not work for all the cases. I would go with the 1st approach and move the call to this function after model.targeted_module_names is updated.

[Member] Sounds good. This has yet to be updated, right?

[Contributor Author] Yes, I will do this in the next commit.

[Contributor Author] @BenjaminBossan

Moving this after model.targeted_modules_name is populated is tough, as the loop which populates this (https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py#L773-L819) also needs to check and skip if the layers are tied.

Reversing the order would mean that we may end up adding adapters where they're not required. The subsequent code would become more involved, but essentially, we would have to remove adapters from all tied layers, re-add in embed_tokens, and proceed to tie remaining adapters to this. This is an opinionated solve which has the least complexity, according to me.

We can go with (1) in your original comment and redo a few things, or keep the current flow and go with (3).

I think the above might have become tough to follow 😅, so let me know and I can share some schematics. Will wait for your input.

target_modules.add("embed_tokens")

for m in tied_weight_keys:
target_modules.add(m)

peft_config.modules_to_tie = missing_keys
setattr(peft_config, "target_modules", target_modules)
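As a sanity check on the transposition used in _create_and_replace above (tied_adapters = {"lora_A": emb_B.t(), "lora_B": emb_A.t()}), here is a standalone sketch with illustrative shapes showing that reusing the embedding adapter's factors transposed gives the tied Linear the same delta weight as the embedding adapter:

    import torch

    vocab, dim, r = 32, 8, 4

    # PEFT stores the Embedding adapter as A with shape (r, vocab) and B with shape
    # (dim, r); the embedding delta is (B @ A).T, matching the (vocab, dim) weight.
    emb_A = torch.randn(r, vocab)
    emb_B = torch.randn(dim, r)
    delta_embedding = (emb_B @ emb_A).T

    # A tied lm_head (Linear with weight shape (vocab, dim)) computes its delta as
    # lora_B @ lora_A. Reusing the embedding factors transposed reproduces it.
    lora_A = emb_B.t()  # (r, dim)
    lora_B = emb_A.t()  # (vocab, r)
    delta_lm_head = lora_B @ lora_A

    assert torch.allclose(delta_embedding, delta_lm_head)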
87 changes: 71 additions & 16 deletions src/peft/tuners/tuners_utils.py
@@ -363,6 +363,30 @@ def _prepare_model(self, peft_config: PeftConfig, model: nn.Module):
"""
pass

@staticmethod
def _check_tied_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Match[str] | None:
"""
A helper method to check if the passed module's key name matches any of the tied modules

Args:
peft_config (`PeftConfig`):
A config to match tied modules from.
key (`str`):
A key to search any matches in config.

Returns:
`bool`
True if key matches any tied modules from config, False if no match found.
"""
_target_modules_to_tie = getattr(peft_config, "target_modules_to_tie", {}) or {}

if key in _target_modules_to_tie or any(
key.endswith(f".{target_key}") for target_key in _target_modules_to_tie
):
return True

return False

@staticmethod
def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Match[str] | None:
"""
@@ -699,6 +723,7 @@ def inject_adapter(
excluded_modules = []
unmatched_modules = []
targeted_modules_from_peft_config: list[str] = [] # only relevant if state_dict is passed
targets_to_tie: list[str] = []
# Note: If possible, all checks should be performed *at the start of this method*.
# This way, we can raise early if something goes wrong, without leaving the model
# in a bad (half-initialized) state.
@@ -792,6 +817,9 @@
elif not result:
unmatched_modules.append(key)
else:
if self._check_tied_module_exists(peft_config, key):
targets_to_tie.append(key)
continue
self.targeted_module_names.append(key)
parent, target, target_name = _get_submodules(model, key)
self._check_target_module_compatiblity(peft_config, model, target_name)
@@ -805,6 +833,9 @@
if key not in module_names:
unmatched_modules.append(key)
else:
if self._check_tied_module_exists(peft_config, key):
targets_to_tie.append(key)
continue
self.targeted_module_names.append(key)
parent, target, target_name = _get_submodules(model, key)
self._check_target_module_compatiblity(peft_config, model, target_name)
@@ -824,6 +855,15 @@
peft_config=peft_config, model=model, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage
)

# Another loop for tying target modules
for key in targets_to_tie:
self.targeted_module_names.append(key)
parent, target, target_name = _get_submodules(model, key)
self._check_target_module_compatiblity(peft_config, model, target_name)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)

####################
# CHECK FOR ERRORS #
####################
@@ -910,15 +950,6 @@ def inject_adapter(
RuntimeWarning,
)

tied_target_modules = self._get_tied_target_modules(model=model)
if tied_target_modules:
warnings.warn(
f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
"This can lead to complications, for example when merging the adapter "
"or converting your model to formats other than safetensors. "
"See for example https://github.com/huggingface/peft/issues/2018."
)

################
# HOUSEKEEPING #
################
@@ -1198,6 +1229,24 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys):
"""
This method adds modules to tie to `peft_config` so that those modules can be tied downstream. By default this
method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this.

Check `peft.tuners.lora.LoraModel._add_modules_to_tie` for an example.
"""
msg = (
"Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
"but no implementation exists to tie the adapters. "
"This can lead to complications, for example when merging the adapter "
"or converting your model to formats other than safetensors. "
"Check the discussion here: https://github.com/huggingface/peft/issues/2777"
)
warnings.warn(msg)

def _add_targets_to_tie(self, peft_config, tied_weight_keys):
"""
This method adds targets to tie to `peft_config` so that those modules can be tied downstream. By default this
method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this.

Check `peft.tuners.lora.LoraModel._add_targets_to_tie` for an example.
"""
msg = (
"Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
@@ -1210,27 +1259,33 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys):

def _check_tied_modules(self, model: nn.Module, peft_config):
"""
Checks if any of the tied layers are targetted via `modules_to_save`. Updates the `peft_config.modules_to_tie`
with any layers that needs to be tied
Checks if any of the tied layers are targeted via `modules_to_save` or `target_modules`. Updates the
`peft_config` in place with any layers/adapters that need to be tied.
"""
modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save)

target_modules = set(getattr(peft_config, "target_modules", []) or [])
is_embedding_in_target = any(m in EMBEDDING_LAYER_NAMES for m in target_modules)

tied_weight_keys = self._get_tied_weight_keys(model)

if getattr(peft_config, "ensure_weight_tying", False):
if is_embedding_to_save and tied_weight_keys:
self._add_modules_to_tie(peft_config, tied_weight_keys)
if (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys:
[Member] I think this whole block, line 1288-1298, can be replaced with:

        if getattr(peft_config, "ensure_weight_tying", False):
            if tied_weight_keys:
                if is_embedding_to_save:
                    self._add_modules_to_tie(peft_config, tied_weight_keys)
                elif is_embedding_in_target:
                    self._add_targets_to_tie(peft_config, tied_weight_keys)
                else:
                    warnings.warn(
                        "You have requested `ensure_weight_tying`, but no tied modules are added in either "
                        "`modules_to_save` or `target_modules`"
                    )
            else:
                warnings.warn("You have requested `ensure_weight_tying`, but no tied modules were found in the model")

I think this is cleaner.

if is_embedding_to_save:
self._add_modules_to_tie(peft_config, tied_weight_keys)
elif is_embedding_in_target:
self._add_targets_to_tie(peft_config, tied_weight_keys)

elif not is_embedding_to_save and tied_weight_keys:
elif not (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys:
warnings.warn(
"You have requested `ensure_weight_tying`, but no tied modules are added in `modules_to_save`"
"You have requested `ensure_weight_tying`, but no tied modules are added in either `modules_to_save` or `target_modules`"
)

elif not tied_weight_keys:
warnings.warn("You have requested `ensure_weight_tying`, but no tied modules were found in the model")

elif is_embedding_to_save and tied_weight_keys:
elif (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys:
if hasattr(peft_config, "ensure_weight_tying"):
msg = (
"Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
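The suffix matching used by _check_tied_module_exists can be exercised on its own; a small sketch with illustrative module names:

    def matches_tied_module(tied_keys: set[str], key: str) -> bool:
        # Mirrors the check in _check_tied_module_exists: exact match, or a
        # fully-qualified key that ends in ".<tied key>".
        return key in tied_keys or any(key.endswith(f".{t}") for t in tied_keys)

    tied = {"lm_head"}
    assert matches_tied_module(tied, "lm_head")
    assert matches_tied_module(tied, "model.lm_head")             # fully-qualified name
    assert not matches_tied_module(tied, "model.custom_lm_head")  # no partial suffix match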