From 4c6d15f18be50426f2f13f1e0268db0cc7fc986d Mon Sep 17 00:00:00 2001 From: romit Date: Wed, 29 Oct 2025 05:56:07 +0000 Subject: [PATCH 01/21] Tests and inital implementation for embed_tokens --- src/peft/tuners/lora/config.py | 8 ++- src/peft/tuners/lora/model.py | 27 +++++++++- src/peft/tuners/tuners_utils.py | 82 +++++++++++++++++++++++++++--- tests/test_initialization.py | 90 ++++++++++++++++++++++++++++++++- 4 files changed, 196 insertions(+), 11 deletions(-) diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 70170ca2c5..e7e0b1be4b 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -382,6 +382,11 @@ class LoraConfig(PeftConfig): `target_parameters`. As an example, for Llama4, you can pass: `target_parameters=['feed_forward.experts.gate_up_proj', 'feed_forward.experts.down_proj]`. Passing a string for regex matching is not implemented yet. + ensure_weight_tying (`bool`, *optional*) + Whether to tie weights or not after peft initialization. This will ensure that the adapters added to the + tied layers are also tied. This is only applicable for layers passed via `modules_to_save` and + `target_modules`. + """ r: int = field(default=8, metadata={"help": "Lora attention dimension"}) @@ -670,7 +675,7 @@ class LoraConfig(PeftConfig): "Whether to tie weights or not after peft initialization. " "This will ensure that the adapters added to the tied layers " "are also tied. This is only applicable for layers passed via " - "`modules_to_save`." + "`modules_to_save` and and `target_modules`." ) }, ) @@ -695,6 +700,7 @@ def __post_init__(self): if self.ensure_weight_tying: self.modules_to_tie = None + self.target_modules_to_tie = None if isinstance(self.target_parameters, str): raise TypeError("`target_parameters` must be a list of strings or None.") diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 2e76e13ee1..cae125e848 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -249,9 +249,19 @@ def _create_and_replace( if adapter_name not in self.active_adapters: # adding an additional adapter: it is not automatically trainable new_module.requires_grad_(False) - self._replace_module(parent, target_name, new_module, target) - def _replace_module(self, parent, child_name, new_module, child): + is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or []) + + self._replace_module( + parent=parent, + child_name=target_name, + new_module=new_module, + child=target, + is_tied=is_tied, + adapter_name=adapter_name, + ) + + def _replace_module(self, parent, child_name, new_module, child, is_tied, adapter_name): # override in LoraModel to handle quantized weights properly setattr(parent, child_name, new_module) @@ -279,6 +289,11 @@ def _replace_module(self, parent, child_name, new_module, child): if not any(p.device == meta for p in module.parameters()): module.to(weight.device) + if is_tied: + tied_module = self.model.get_input_embeddings() + new_module.lora_A[adapter_name].weight = tied_module.lora_embedding_B[adapter_name] + new_module.lora_B[adapter_name].weight = tied_module.lora_embedding_A[adapter_name] + @staticmethod def _create_new_module(lora_config, adapter_name, target, **kwargs): # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. 
The order matters, @@ -811,3 +826,11 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys): missing_keys = set(tied_weight_keys) - modules_to_save peft_config.modules_to_tie = missing_keys + + def _add_targets_to_tie(self, peft_config, tied_weight_keys): + target_modules = set(getattr(peft_config, "target_modules", []) or []) + missing_keys = set(tied_weight_keys) - target_modules + + peft_config.target_modules_to_tie = missing_keys + for m in missing_keys: + peft_config.target_modules.add(m) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index fb8bc2f8ed..ef7967a576 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -774,6 +774,9 @@ def inject_adapter( if not key: continue + if key in getattr(peft_config, "target_module_to_tie", {}): + continue + # It is possible that we're adding an additional adapter, so if we encounter a key that clearly belongs to a # previous adapter we can skip here since we don't want to interfere with adapter internals. for adapter_key in existing_adapter_prefixes: @@ -824,6 +827,47 @@ def inject_adapter( peft_config=peft_config, model=model, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage ) + # Another loop for tying target modules + for key, module in named_modules: + if not key: + continue + + if key not in getattr(peft_config, "target_module_to_tie", {}): + continue + + if state_dict is None: + result = self._check_target_module_exists(peft_config, key) + if isinstance(result, _ExcludedModule): + excluded_modules.append(key) + elif not result: + unmatched_modules.append(key) + else: + self.targeted_module_names.append(key) + parent, target, target_name = _get_submodules(model, key) + self._check_target_module_compatiblity(peft_config, model, target_name) + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + with ctx(): + self._create_and_replace( + peft_config, adapter_name, target, target_name, parent, current_key=key + ) + else: + # use the state_dict to match modules instead + if key not in module_names: + unmatched_modules.append(key) + else: + self.targeted_module_names.append(key) + parent, target, target_name = _get_submodules(model, key) + self._check_target_module_compatiblity(peft_config, model, target_name) + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + with ctx(): + self._create_and_replace( + peft_config, adapter_name, target, target_name, parent, current_key=key + ) + + # still record what would have been matched via the config so that the two results can be compared + if self._check_target_module_exists(peft_config, key): + targeted_modules_from_peft_config.append(key) + #################### # CHECK FOR ERRORS # #################### @@ -1198,6 +1242,24 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys): """ This method adds modules to tie to `peft_config` so that those modules can be tied downstream. By default this method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this. + + Check `peft.tuners.lora.LoraModel._add_modules_to_tie` for an example. + """ + msg = ( + "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, " + "but no implementation exists to tie the adapters. " + "This can lead to complications, for example when merging the adapter " + "or converting your model to formats other than safetensors. 
" + "Check the discussion here: https://github.com/huggingface/peft/issues/2777" + ) + warnings.warn(msg) + + def _add_targets_to_tie(self, peft_config, tied_weight_keys): + """ + This method adds targets to tie to `peft_config` so that those modules can be tied downstream. By default this + method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this. + + Check `peft.tuners.lora.LoraModel._add_targets_to_tie` for an example. """ msg = ( "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, " @@ -1210,27 +1272,33 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys): def _check_tied_modules(self, model: nn.Module, peft_config): """ - Checks if any of the tied layers are targetted via `modules_to_save`. Updates the `peft_config.modules_to_tie` - with any layers that needs to be tied + Checks if any of the tied layers are targetted via `modules_to_save` or `target_modules`. Updates the + `peft_config` in place with any layers/adapters that needs to be tied """ modules_to_save = set(getattr(peft_config, "modules_to_save", []) or []) is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save) + target_modules = set(getattr(peft_config, "target_modules", []) or []) + is_embedding_in_target = any(m in EMBEDDING_LAYER_NAMES for m in target_modules) + tied_weight_keys = self._get_tied_weight_keys(model) if getattr(peft_config, "ensure_weight_tying", False): - if is_embedding_to_save and tied_weight_keys: - self._add_modules_to_tie(peft_config, tied_weight_keys) + if (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys: + if is_embedding_to_save: + self._add_modules_to_tie(peft_config, tied_weight_keys) + elif is_embedding_in_target: + self._add_targets_to_tie(peft_config, tied_weight_keys) - elif not is_embedding_to_save and tied_weight_keys: + elif not (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys: warnings.warn( - "You have requested `ensure_weight_tying`, but no tied modules are added in `modules_to_save`" + "You have requested `ensure_weight_tying`, but no tied modules are added in either `modules_to_save` or `target_modules`" ) elif not tied_weight_keys: warnings.warn("You have requested `ensure_weight_tying`, but no tied modules were found in the model") - elif is_embedding_to_save and tied_weight_keys: + elif (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys: if hasattr(peft_config, "ensure_weight_tying"): msg = ( "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, " diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 37e65dcf25..62d745f0d1 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -4910,7 +4910,7 @@ def test_weight_tying_tied_model_no_embed_lora(self): ensure_weight_tying=True, ) - with pytest.warns(UserWarning, match="no tied modules are added in `modules_to_save`"): + with pytest.warns(UserWarning, match="no tied modules are added"): model = get_peft_model(model, embed_token_config) assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding) @@ -4938,3 +4938,91 @@ def test_weight_tying_tied_model_lokr(self): assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), ( "LM head is not of type nn.linear" ) + + @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens"]) + def test_weight_tying_tied_model_target_modules_lora(self, layer): + # Same as `test_weight_tying_tied_model_lora` but the 
tied module is passed + # in `target_modules` instead of `modules_to_save`. + model = self.get_lm_model() + + embed_token_config = LoraConfig( + target_modules=[layer, "linear"], + ensure_weight_tying=True, + ) + + model = get_peft_model(model, embed_token_config) + + assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer) + assert isinstance(model.base_model.model.lm_head, LoraLayer) + + # Since embed_tokens and lm_head weights are transpose of each other + # lm_head lora_A == embed_tokens lora_B + adapter_name = "default" + + embed_lora_A = model.base_model.model.model.embed_tokens.lora_embedding_A[adapter_name] + embed_lora_B = model.base_model.model.model.embed_tokens.lora_embedding_B[adapter_name] + + lm_lora_A = model.base_model.model.lm_head.lora_A[adapter_name].weight + lm_lora_B = model.base_model.model.lm_head.lora_B[adapter_name].weight + + assert torch.allclose(embed_lora_A, lm_lora_B) + assert torch.allclose(embed_lora_B, lm_lora_A) + assert embed_lora_A is lm_lora_B + assert embed_lora_B is lm_lora_A + + @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens"]) + def test_weight_tying_non_tied_model_target_modules_lora(self, layer): + # When model weights are not tied, ensure a warning is raised even if + # the tied module name is present in `target_modules`. + model = self.get_lm_model(tie_weights=False) + embed_token_config = LoraConfig( + target_modules=[layer, "linear"], + ensure_weight_tying=True, + ) + with pytest.warns(UserWarning, match="no tied modules were found in the model"): + model = get_peft_model(model, embed_token_config) + + if layer == "embed_tokens": + assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer) + assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear) + elif layer == "lm_head": + assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding) + assert isinstance(model.base_model.model.lm_head, LoraLayer) + else: + raise NotImplementedError("Layer type {layer} is not supported for this test") + + @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens"]) + def test_not_weight_tying_tied_model_target_modules_lora(self, layer): + # If ensure_weight_tying is False, a warning should be raised even when + # the tied module is present in `target_modules`. 
+ model = self.get_lm_model() + embed_token_config = LoraConfig( + target_modules=[layer, "linear"], + ensure_weight_tying=False, + ) + with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"): + model = get_peft_model(model, embed_token_config) + + if layer == "embed_tokens": + assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer) + assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear) + elif layer == "lm_head": + assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding) + assert isinstance(model.base_model.model.lm_head, LoraLayer) + else: + raise NotImplementedError("Layer type {layer} is not supported for this test") + + def test_weight_tying_tied_model_target_modules_lokr(self): + from peft.tuners.lokr.layer import LoKrLayer + + model = self.get_lm_model() + + embed_token_config = LoKrConfig(target_modules=["linear", "lm_head"]) + + with pytest.warns(UserWarning, match="no implementation exists to tie the adapters"): + model = get_peft_model(model, embed_token_config) + + assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding), ( + "Embed tokens is not updated as a LoRA layer" + ) + assert isinstance(model.base_model.model.lm_head, LoKrLayer), "LM head is not of type nn.linear" From 4b9122086f51461b073e207cc0c7ae42d4980327 Mon Sep 17 00:00:00 2001 From: romit Date: Thu, 30 Oct 2025 09:52:15 +0000 Subject: [PATCH 02/21] Minor fixes --- src/peft/tuners/lora/model.py | 13 ++++++-- src/peft/tuners/tuners_utils.py | 59 ++++++++++++++++++++++----------- src/peft/utils/other.py | 2 ++ tests/test_initialization.py | 20 +++++------ tests/test_trainable_tokens.py | 9 ----- 5 files changed, 61 insertions(+), 42 deletions(-) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index cae125e848..663bbb8702 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -290,9 +290,16 @@ def _replace_module(self, parent, child_name, new_module, child, is_tied, adapte module.to(weight.device) if is_tied: - tied_module = self.model.get_input_embeddings() - new_module.lora_A[adapter_name].weight = tied_module.lora_embedding_B[adapter_name] - new_module.lora_B[adapter_name].weight = tied_module.lora_embedding_A[adapter_name] + if child_name == "embed_tokens": + tied_module = self.model.get_output_embeddings() + new_module.lora_embedding_A[adapter_name] = tied_module.lora_B[adapter_name].weight.T + new_module.lora_embedding_B[adapter_name] = tied_module.lora_A[adapter_name].weight.T + elif child_name == "lm_head": + tied_module = self.model.get_input_embeddings() + new_module.lora_A[adapter_name].weight = tied_module.lora_embedding_B[adapter_name].T + new_module.lora_B[adapter_name].weight = tied_module.lora_embedding_A[adapter_name].T + else: + raise NotImplementedError(f"Tying adapters is not yet supported for layer {child_name}") @staticmethod def _create_new_module(lora_config, adapter_name, target, **kwargs): diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index ef7967a576..fcc08e0d91 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -363,6 +363,30 @@ def _prepare_model(self, peft_config: PeftConfig, model: nn.Module): """ pass + @staticmethod + def _check_tied_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Match[str] | None: + """ + A helper method to check if the passed module's key name matches any of the tied modules + + Args: + config 
(`PeftConfig`): + A config to match target modules from. + key (`str`): + A key to search any matches in config. + + Returns: + `bool` + True if key matches any tied modules from config, False if no match found. + """ + _target_modules_to_tie = getattr(peft_config, "target_modules_to_tie", {}) or {} + + if key in _target_modules_to_tie or any( + key.endswith(f".{target_key}") for target_key in _target_modules_to_tie + ): + return True + + return False + @staticmethod def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Match[str] | None: """ @@ -774,7 +798,7 @@ def inject_adapter( if not key: continue - if key in getattr(peft_config, "target_module_to_tie", {}): + if self._check_tied_module_exists(peft_config, key): continue # It is possible that we're adding an additional adapter, so if we encounter a key that clearly belongs to a @@ -832,7 +856,7 @@ def inject_adapter( if not key: continue - if key not in getattr(peft_config, "target_module_to_tie", {}): + if not self._check_tied_module_exists(peft_config, key): continue if state_dict is None: @@ -954,15 +978,6 @@ def inject_adapter( RuntimeWarning, ) - tied_target_modules = self._get_tied_target_modules(model=model) - if tied_target_modules: - warnings.warn( - f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. " - "This can lead to complications, for example when merging the adapter " - "or converting your model to formats other than safetensors. " - "See for example https://github.com/huggingface/peft/issues/2018." - ) - ################ # HOUSEKEEPING # ################ @@ -1214,27 +1229,31 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]: tied_target_modules.append(target_module) return tied_target_modules - def _get_tied_weight_keys(self, model: nn.Module, prefix="") -> list[str]: + def _get_tied_weight_keys(self, model: nn.Module, prefix="") -> set[str]: """ Get the list of modules that needs to be tied For example: For models which have `embed_tokens` and `lm_head` as the tied keys this function will return [`lm_head`] - From: https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/modeling_utils.py#L563 + Adapted from: https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/modeling_utils.py#L563 """ - tied_weight_keys = [] + tied_weight_keys = set() if getattr(model, "_tied_weights_keys", None) is not None: - names = [f"{prefix}.{k}" if prefix else k for k in model._tied_weights_keys] - tied_weight_keys.extend(names) + names = {f"{prefix}.{k}" if prefix else k for k in model._tied_weights_keys} + tied_weight_keys.update(names) if getattr(model, "_dynamic_tied_weights_keys", None) is not None: - names = [f"{prefix}.{k}" if prefix else k for k in model._dynamic_tied_weights_keys] - tied_weight_keys.extend(names) + names = {f"{prefix}.{k}" if prefix else k for k in model._dynamic_tied_weights_keys} + tied_weight_keys.update(names) for name, submodule in model.named_children(): local_prefix = f"{prefix}.{name}" if prefix else name - tied_weight_keys.extend(self._get_tied_weight_keys(submodule, prefix=local_prefix)) + tied_weight_keys.update(self._get_tied_weight_keys(submodule, prefix=local_prefix)) + + tied_weight_keys = {".".join(n.split(".")[:-1]) for n in tied_weight_keys} - tied_weight_keys = [".".join(n.split(".")[:-1]) for n in tied_weight_keys] + # If there's at least one tied key add `embed_tokens` to the set + if tied_weight_keys: + tied_weight_keys.add("embed_tokens") return tied_weight_keys diff 
--git a/src/peft/utils/other.py b/src/peft/utils/other.py index 07985e9b6f..8d73affaca 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -1437,6 +1437,8 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n # Tie the modules if any tied layer is passed in `modules_to_save`. # This should always be called after # `_set_trainable` is called for `modules_to_save`. + + # Tied module should either be input embedding or output embedding based on which module to tie tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name) _set_trainable( model, diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 62d745f0d1..59640a1593 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -4837,9 +4837,13 @@ def prepare_inputs_for_generation(self): def get_input_embeddings(self): return self.model.embed_tokens + def get_output_embeddings(self): + return self.lm_head + return CausalLM().eval().to(self.torch_device) - def test_weight_tying_tied_model_lora(self): + @pytest.mark.parametrize("layer", ["embed_tokens"]) + def test_weight_tying_tied_model_lora(self, layer): # If weight tying is enabled and `embed_tokens` # is passed as a `modules_to_save`, it needs to be ensured # that lm_head is tied to the adapter added to `embed_tokens` @@ -4847,19 +4851,15 @@ def test_weight_tying_tied_model_lora(self): model = self.get_lm_model() embed_token_config = LoraConfig( - modules_to_save=["embed_tokens"], + modules_to_save=[layer], target_modules=["linear"], ensure_weight_tying=True, ) model = get_peft_model(model, embed_token_config) - assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), ( - "Embed tokens is not added in Modules to Save" - ) - assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), ( - "Embed tokens and LM head types are not same" - ) + assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper) + assert isinstance(model.base_model.model.lm_head, ModulesToSaveWrapper) # Validating that all model parameters are same embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters()) @@ -4965,8 +4965,8 @@ def test_weight_tying_tied_model_target_modules_lora(self, layer): lm_lora_A = model.base_model.model.lm_head.lora_A[adapter_name].weight lm_lora_B = model.base_model.model.lm_head.lora_B[adapter_name].weight - assert torch.allclose(embed_lora_A, lm_lora_B) - assert torch.allclose(embed_lora_B, lm_lora_A) + assert torch.allclose(embed_lora_A, lm_lora_B.T) + assert torch.allclose(embed_lora_B, lm_lora_A.T) assert embed_lora_A is lm_lora_B assert embed_lora_B is lm_lora_A diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index a642fe54c6..6a2b095a5b 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -677,15 +677,6 @@ def test_weight_tying_applied_when_model_is_tied_standalone(self, model_weight_t assert merged_model.model.decoder.embed_tokens.weight.data_ptr() == merged_model.lm_head.weight.data_ptr() - def test_weight_tying_normally_issues_warning(self, model_weight_tied, recwarn): - # When using models with weight tying and targeting the embedding or the tied layer should raise a warning. 
- peft_config = LoraConfig(target_modules=["embed_tokens"]) - peft_model = get_peft_model(model_weight_tied, peft_config) - - warnings = [w.message.args[0] for w in recwarn] - warnings = [msg for msg in warnings if "Model with `tie_word_embeddings=True` and the" in msg] - assert warnings - def test_weight_tying_state_dict_ignores_tied_weights(self, model_weight_tied): # since weight tying is currently not supported make sure that an error is raised when attempting # to use a model that has tied input/output embeddings From 46b803e4f2b218ed01a790fbf37bee67b00ef77d Mon Sep 17 00:00:00 2001 From: romit Date: Fri, 31 Oct 2025 12:05:39 +0000 Subject: [PATCH 03/21] Fixed all tests and made updates to logic Signed-off-by: romit --- src/peft/tuners/lora/layer.py | 14 ++++ src/peft/tuners/lora/model.py | 61 +++++++++----- src/peft/tuners/tuners_utils.py | 78 ++++++------------ src/peft/utils/other.py | 2 - tests/test_initialization.py | 138 +++++++++++++++++--------------- 5 files changed, 150 insertions(+), 143 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index a338fac0f4..ba2a8e78d5 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -156,6 +156,8 @@ def update_layer( arrow_config: ArrowConfig = None, qalora_group_size: int = 32, inference_mode: bool = False, + is_tied: bool = False, + tied_adapters: dict = {}, **kwargs, ): # collect the kwargs @@ -195,6 +197,16 @@ def update_layer( # Actual trainable parameters self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False) self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias) + if is_tied: + if not tied_adapters: + raise RuntimeError("Layer is marked as tied, but tied adapters are not provided") + + lora_A_params = tied_adapters["lora_A"] + lora_B_params = tied_adapters["lora_B"] + + self.lora_A[adapter_name].weight = torch.nn.Parameter(lora_A_params) + self.lora_B[adapter_name].weight = torch.nn.Parameter(lora_B_params) + self.lora_bias[adapter_name] = lora_bias if use_rslora: @@ -631,6 +643,8 @@ def __init__( use_alora=use_alora, lora_bias=lora_bias, arrow_config=arrow_config, + is_tied=kwargs.get("is_tied", False), + tied_adapters=kwargs.get("tied_adapters"), ) self.is_target_conv_1d_layer = is_target_conv_1d_layer diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 663bbb8702..8e28f98ab7 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -187,6 +187,15 @@ def _create_and_replace( r = lora_config.rank_pattern.get(r_key, lora_config.r) alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha) + is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or []) + tied_adapters = {} + if is_tied: + tied_module = self.model.get_input_embeddings() + emb_A = tied_module.lora_embedding_A[adapter_name] + emb_B = tied_module.lora_embedding_B[adapter_name] + + tied_adapters = {"lora_A": emb_B.t(), "lora_B": emb_A.t()} + kwargs = { "r": r, "lora_alpha": alpha, @@ -204,6 +213,8 @@ def _create_and_replace( "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), "parameter_name": parameter_name, + "is_tied": is_tied, + "tied_adapters": tied_adapters, } # for torchao merging, we need the get_apply_tensor_subclass from the quantization config @@ -250,8 +261,6 @@ def _create_and_replace( # adding an additional adapter: it is not automatically trainable new_module.requires_grad_(False) - 
is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or []) - self._replace_module( parent=parent, child_name=target_name, @@ -289,18 +298,6 @@ def _replace_module(self, parent, child_name, new_module, child, is_tied, adapte if not any(p.device == meta for p in module.parameters()): module.to(weight.device) - if is_tied: - if child_name == "embed_tokens": - tied_module = self.model.get_output_embeddings() - new_module.lora_embedding_A[adapter_name] = tied_module.lora_B[adapter_name].weight.T - new_module.lora_embedding_B[adapter_name] = tied_module.lora_A[adapter_name].weight.T - elif child_name == "lm_head": - tied_module = self.model.get_input_embeddings() - new_module.lora_A[adapter_name].weight = tied_module.lora_embedding_B[adapter_name].T - new_module.lora_B[adapter_name].weight = tied_module.lora_embedding_A[adapter_name].T - else: - raise NotImplementedError(f"Tying adapters is not yet supported for layer {child_name}") - @staticmethod def _create_new_module(lora_config, adapter_name, target, **kwargs): # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters, @@ -828,16 +825,36 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap return tensors_lora - def _add_modules_to_tie(self, peft_config, tied_weight_keys): - modules_to_save = set(getattr(peft_config, "modules_to_save", []) or []) - missing_keys = set(tied_weight_keys) - modules_to_save + def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]): + """ + Tied weight keys contains the layers tied to the embedding layer. Add embedding layer and remove rest of the + tied layers from `module_to_save`. Maintain a separate set for layers to be tied - peft_config.modules_to_tie = missing_keys + Args: + peft_config (LoraConfig): _description_ + tied_weight_keys (list[str]): _description_ + """ + tied_weight_keys = set(tied_weight_keys) + setattr(peft_config, "modules_to_tie", tied_weight_keys) + + modules_to_save = getattr(peft_config, "modules_to_save", []) or [] + if "embed_tokens" not in modules_to_save: + modules_to_save.append("embed_tokens") + + for m in tied_weight_keys: + if m in modules_to_save: + modules_to_save.remove(m) + + setattr(peft_config, "modules_to_save", modules_to_save) def _add_targets_to_tie(self, peft_config, tied_weight_keys): + tied_weight_keys = set(tied_weight_keys) + setattr(peft_config, "target_modules_to_tie", tied_weight_keys) + target_modules = set(getattr(peft_config, "target_modules", []) or []) - missing_keys = set(tied_weight_keys) - target_modules + target_modules.add("embed_tokens") + + for m in tied_weight_keys: + target_modules.add(m) - peft_config.target_modules_to_tie = missing_keys - for m in missing_keys: - peft_config.target_modules.add(m) + setattr(peft_config, "target_modules", target_modules) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index fcc08e0d91..c3d207ff5b 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -723,6 +723,7 @@ def inject_adapter( excluded_modules = [] unmatched_modules = [] targeted_modules_from_peft_config: list[str] = [] # only relevant if state_dict is passed + targets_to_tie: list[str] = [] # Note: If possible, all checks should be performed *at the start of this method*. # This way, we can raise early if something goes wrong, without leaving the model # in a bad (half-initialized) state. 
@@ -798,9 +799,6 @@ def inject_adapter( if not key: continue - if self._check_tied_module_exists(peft_config, key): - continue - # It is possible that we're adding an additional adapter, so if we encounter a key that clearly belongs to a # previous adapter we can skip here since we don't want to interfere with adapter internals. for adapter_key in existing_adapter_prefixes: @@ -819,6 +817,9 @@ def inject_adapter( elif not result: unmatched_modules.append(key) else: + if self._check_tied_module_exists(peft_config, key): + targets_to_tie.append(key) + continue self.targeted_module_names.append(key) parent, target, target_name = _get_submodules(model, key) self._check_target_module_compatiblity(peft_config, model, target_name) @@ -832,6 +833,9 @@ def inject_adapter( if key not in module_names: unmatched_modules.append(key) else: + if self._check_tied_module_exists(peft_config, key): + targets_to_tie.append(key) + continue self.targeted_module_names.append(key) parent, target, target_name = _get_submodules(model, key) self._check_target_module_compatiblity(peft_config, model, target_name) @@ -852,45 +856,13 @@ def inject_adapter( ) # Another loop for tying target modules - for key, module in named_modules: - if not key: - continue - - if not self._check_tied_module_exists(peft_config, key): - continue - - if state_dict is None: - result = self._check_target_module_exists(peft_config, key) - if isinstance(result, _ExcludedModule): - excluded_modules.append(key) - elif not result: - unmatched_modules.append(key) - else: - self.targeted_module_names.append(key) - parent, target, target_name = _get_submodules(model, key) - self._check_target_module_compatiblity(peft_config, model, target_name) - ctx = init_empty_weights if low_cpu_mem_usage else nullcontext - with ctx(): - self._create_and_replace( - peft_config, adapter_name, target, target_name, parent, current_key=key - ) - else: - # use the state_dict to match modules instead - if key not in module_names: - unmatched_modules.append(key) - else: - self.targeted_module_names.append(key) - parent, target, target_name = _get_submodules(model, key) - self._check_target_module_compatiblity(peft_config, model, target_name) - ctx = init_empty_weights if low_cpu_mem_usage else nullcontext - with ctx(): - self._create_and_replace( - peft_config, adapter_name, target, target_name, parent, current_key=key - ) - - # still record what would have been matched via the config so that the two results can be compared - if self._check_target_module_exists(peft_config, key): - targeted_modules_from_peft_config.append(key) + for key in targets_to_tie: + self.targeted_module_names.append(key) + parent, target, target_name = _get_submodules(model, key) + self._check_target_module_compatiblity(peft_config, model, target_name) + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + with ctx(): + self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key) #################### # CHECK FOR ERRORS # @@ -1229,31 +1201,27 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]: tied_target_modules.append(target_module) return tied_target_modules - def _get_tied_weight_keys(self, model: nn.Module, prefix="") -> set[str]: + def _get_tied_weight_keys(self, model: nn.Module, prefix="") -> list[str]: """ Get the list of modules that needs to be tied For example: For models which have `embed_tokens` and `lm_head` as the tied keys this function will return [`lm_head`] - Adapted from: 
https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/modeling_utils.py#L563 + From: https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/modeling_utils.py#L563 """ - tied_weight_keys = set() + tied_weight_keys = [] if getattr(model, "_tied_weights_keys", None) is not None: - names = {f"{prefix}.{k}" if prefix else k for k in model._tied_weights_keys} - tied_weight_keys.update(names) + names = [f"{prefix}.{k}" if prefix else k for k in model._tied_weights_keys] + tied_weight_keys.extend(names) if getattr(model, "_dynamic_tied_weights_keys", None) is not None: - names = {f"{prefix}.{k}" if prefix else k for k in model._dynamic_tied_weights_keys} - tied_weight_keys.update(names) + names = [f"{prefix}.{k}" if prefix else k for k in model._dynamic_tied_weights_keys] + tied_weight_keys.extend(names) for name, submodule in model.named_children(): local_prefix = f"{prefix}.{name}" if prefix else name - tied_weight_keys.update(self._get_tied_weight_keys(submodule, prefix=local_prefix)) - - tied_weight_keys = {".".join(n.split(".")[:-1]) for n in tied_weight_keys} + tied_weight_keys.extend(self._get_tied_weight_keys(submodule, prefix=local_prefix)) - # If there's at least one tied key add `embed_tokens` to the set - if tied_weight_keys: - tied_weight_keys.add("embed_tokens") + tied_weight_keys = [".".join(n.split(".")[:-1]) for n in tied_weight_keys] return tied_weight_keys diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 8d73affaca..07985e9b6f 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -1437,8 +1437,6 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n # Tie the modules if any tied layer is passed in `modules_to_save`. # This should always be called after # `_set_trainable` is called for `modules_to_save`. 
- - # Tied module should either be input embedding or output embedding based on which module to tie tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name) _set_trainable( model, diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 59640a1593..ef37160524 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -4842,7 +4842,7 @@ def get_output_embeddings(self): return CausalLM().eval().to(self.torch_device) - @pytest.mark.parametrize("layer", ["embed_tokens"]) + @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]]) def test_weight_tying_tied_model_lora(self, layer): # If weight tying is enabled and `embed_tokens` # is passed as a `modules_to_save`, it needs to be ensured @@ -4851,7 +4851,7 @@ def test_weight_tying_tied_model_lora(self, layer): model = self.get_lm_model() embed_token_config = LoraConfig( - modules_to_save=[layer], + modules_to_save=layer if isinstance(layer, list) else [layer], target_modules=["linear"], ensure_weight_tying=True, ) @@ -4867,41 +4867,46 @@ def test_weight_tying_tied_model_lora(self, layer): for k in embed_np.keys(): assert torch.allclose(embed_np[k], lm_head_np[k]) - assert embed_np[k] is lm_head_np[k] + assert embed_np[k].data_ptr() == lm_head_np[k].data_ptr() - def test_weight_tying_non_tied_model_lora(self): - model = self.get_lm_model(tie_weights=False) + @pytest.mark.parametrize( + "layer,tie_weights", + [ + ("lm_head", True), + ("lm_head", False), + ("embed_tokens", True), + ("embed_tokens", False), + (["embed_tokens", "lm_head"], True), + (["embed_tokens", "lm_head"], False), + ], + ) + def test_alt_weight_tying_tied_model_lora(self, layer, tie_weights): + model = self.get_lm_model(tie_weights=tie_weights) embed_token_config = LoraConfig( - modules_to_save=["embed_tokens"], + modules_to_save=layer if isinstance(layer, list) else [layer], target_modules=["linear"], - ensure_weight_tying=True, + ensure_weight_tying=not tie_weights, ) - with pytest.warns(UserWarning, match="no tied modules were found in the model"): - model = get_peft_model(model, embed_token_config) - assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), ( - "Embed tokens is not added in Modules to Save" - ) - assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), ( - "LM head is not of type nn.linear" - ) + if tie_weights: + wrn_msg = "`ensure_weight_tying` is not set to True" + else: + wrn_msg = "no tied modules were found in the model" - def test_not_weight_tying_tied_model_lora(self): - model = self.get_lm_model() - embed_token_config = LoraConfig( - modules_to_save=["embed_tokens"], - target_modules=["linear"], - ensure_weight_tying=False, - ) - with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"): + with pytest.warns(UserWarning, match=wrn_msg): model = get_peft_model(model, embed_token_config) - assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), ( - "Embed tokens is not added in Modules to Save" - ) - assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), ( - "LM head is not of type nn.linear" - ) + if layer == "embed_tokens": + assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper) + assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear) + elif layer == "lm_head": + assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding) + assert 
isinstance(model.base_model.model.lm_head, ModulesToSaveWrapper) + elif layer == ["embed_tokens", "lm_head"]: + assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper) + assert isinstance(model.base_model.model.lm_head, ModulesToSaveWrapper) + else: + raise NotImplementedError("Layer type {layer} is not supported for this test") def test_weight_tying_tied_model_no_embed_lora(self): model = self.get_lm_model() @@ -4935,18 +4940,16 @@ def test_weight_tying_tied_model_lokr(self): assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), ( "Embed tokens is not added in Modules to Save" ) - assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), ( - "LM head is not of type nn.linear" - ) + assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear) - @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens"]) + @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]]) def test_weight_tying_tied_model_target_modules_lora(self, layer): # Same as `test_weight_tying_tied_model_lora` but the tied module is passed # in `target_modules` instead of `modules_to_save`. model = self.get_lm_model() embed_token_config = LoraConfig( - target_modules=[layer, "linear"], + target_modules=["linear"] + layer if isinstance(layer, list) else [layer], ensure_weight_tying=True, ) @@ -4967,40 +4970,35 @@ def test_weight_tying_tied_model_target_modules_lora(self, layer): assert torch.allclose(embed_lora_A, lm_lora_B.T) assert torch.allclose(embed_lora_B, lm_lora_A.T) - assert embed_lora_A is lm_lora_B - assert embed_lora_B is lm_lora_A + assert embed_lora_A.data_ptr() == lm_lora_B.data_ptr() + assert embed_lora_B.data_ptr() == lm_lora_A.data_ptr() - @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens"]) - def test_weight_tying_non_tied_model_target_modules_lora(self, layer): + @pytest.mark.parametrize( + "layer,tie_weights", + [ + ("lm_head", True), + ("lm_head", False), + ("embed_tokens", True), + ("embed_tokens", False), + (["embed_tokens", "lm_head"], True), + (["embed_tokens", "lm_head"], False), + ], + ) + def test_alt_weight_tying_tied_model_target_modules_lora(self, layer, tie_weights): # When model weights are not tied, ensure a warning is raised even if # the tied module name is present in `target_modules`. 
- model = self.get_lm_model(tie_weights=False) + model = self.get_lm_model(tie_weights=tie_weights) embed_token_config = LoraConfig( - target_modules=[layer, "linear"], - ensure_weight_tying=True, + target_modules=["linear"] + layer if isinstance(layer, list) else [layer], + ensure_weight_tying=not tie_weights, ) - with pytest.warns(UserWarning, match="no tied modules were found in the model"): - model = get_peft_model(model, embed_token_config) - if layer == "embed_tokens": - assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer) - assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear) - elif layer == "lm_head": - assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding) - assert isinstance(model.base_model.model.lm_head, LoraLayer) + if tie_weights: + wrn_msg = "`ensure_weight_tying` is not set to True" else: - raise NotImplementedError("Layer type {layer} is not supported for this test") + wrn_msg = "no tied modules were found in the model" - @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens"]) - def test_not_weight_tying_tied_model_target_modules_lora(self, layer): - # If ensure_weight_tying is False, a warning should be raised even when - # the tied module is present in `target_modules`. - model = self.get_lm_model() - embed_token_config = LoraConfig( - target_modules=[layer, "linear"], - ensure_weight_tying=False, - ) - with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"): + with pytest.warns(UserWarning, match=wrn_msg): model = get_peft_model(model, embed_token_config) if layer == "embed_tokens": @@ -5009,6 +5007,20 @@ def test_not_weight_tying_tied_model_target_modules_lora(self, layer): elif layer == "lm_head": assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding) assert isinstance(model.base_model.model.lm_head, LoraLayer) + elif layer == ["embed_tokens", "lm_head"]: + assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer) + assert isinstance(model.base_model.model.lm_head, LoraLayer) + + adapter_name = "default" + + embed_lora_A = model.base_model.model.model.embed_tokens.lora_embedding_A[adapter_name] + embed_lora_B = model.base_model.model.model.embed_tokens.lora_embedding_B[adapter_name] + + lm_lora_A = model.base_model.model.lm_head.lora_A[adapter_name].weight + lm_lora_B = model.base_model.model.lm_head.lora_B[adapter_name].weight + + assert embed_lora_A.data_ptr() != lm_lora_B.data_ptr() + assert embed_lora_B.data_ptr() != lm_lora_A.data_ptr() else: raise NotImplementedError("Layer type {layer} is not supported for this test") @@ -5022,7 +5034,5 @@ def test_weight_tying_tied_model_target_modules_lokr(self): with pytest.warns(UserWarning, match="no implementation exists to tie the adapters"): model = get_peft_model(model, embed_token_config) - assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding), ( - "Embed tokens is not updated as a LoRA layer" - ) - assert isinstance(model.base_model.model.lm_head, LoKrLayer), "LM head is not of type nn.linear" + assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding) + assert isinstance(model.base_model.model.lm_head, LoKrLayer) From 37b1e0642897fa3ef81046663b16dda7e6553110 Mon Sep 17 00:00:00 2001 From: romit Date: Fri, 31 Oct 2025 12:17:03 +0000 Subject: [PATCH 04/21] Nit --- src/peft/tuners/lora/model.py | 11 ++--------- tests/test_initialization.py | 3 --- 2 files changed, 2 insertions(+), 12 
deletions(-) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 8e28f98ab7..41b5abf02d 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -261,16 +261,9 @@ def _create_and_replace( # adding an additional adapter: it is not automatically trainable new_module.requires_grad_(False) - self._replace_module( - parent=parent, - child_name=target_name, - new_module=new_module, - child=target, - is_tied=is_tied, - adapter_name=adapter_name, - ) + self._replace_module(parent=parent, child_name=target_name, new_module=new_module, child=target) - def _replace_module(self, parent, child_name, new_module, child, is_tied, adapter_name): + def _replace_module(self, parent, child_name, new_module, child): # override in LoraModel to handle quantized weights properly setattr(parent, child_name, new_module) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index ef37160524..ededc682e3 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -4837,9 +4837,6 @@ def prepare_inputs_for_generation(self): def get_input_embeddings(self): return self.model.embed_tokens - def get_output_embeddings(self): - return self.lm_head - return CausalLM().eval().to(self.torch_device) @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]]) From 8388aa869473a60589a01e6950ea0583d3612783 Mon Sep 17 00:00:00 2001 From: romit Date: Tue, 4 Nov 2025 12:17:17 +0000 Subject: [PATCH 05/21] Added contigious check for export Signed-off-by: romit --- src/peft/peft_model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 3b7e636416..aceaba0990 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -315,6 +315,11 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion, output_state_dict = save_mutated_as_lora( peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs ) + + for k, v in output_state_dict.items(): + if not v.is_contiguous(): + output_state_dict[k] = v.contiguous() + safe_save_file( output_state_dict, os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME), From cd6c6d0f8f69bc85d4395f306ff1de348d5134cd Mon Sep 17 00:00:00 2001 From: r0 <11757603+romitjain@users.noreply.github.com> Date: Tue, 4 Nov 2025 17:57:33 +0530 Subject: [PATCH 06/21] Apply suggestion from @BenjaminBossan Co-authored-by: Benjamin Bossan --- src/peft/tuners/lora/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index ba2a8e78d5..e0774cbcb7 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -157,7 +157,7 @@ def update_layer( qalora_group_size: int = 32, inference_mode: bool = False, is_tied: bool = False, - tied_adapters: dict = {}, + tied_adapters: Optional[dict[str, nn.Parameter]] = None, **kwargs, ): # collect the kwargs From 0cb44e8bdd329201d0c68f1cb572f5871b4bbc7d Mon Sep 17 00:00:00 2001 From: romit Date: Wed, 5 Nov 2025 12:18:45 +0000 Subject: [PATCH 07/21] Addressed PR comments Signed-off-by: romit --- src/peft/tuners/lora/layer.py | 9 ++++----- src/peft/tuners/lora/model.py | 20 +++++++++++++++----- src/peft/tuners/tuners_utils.py | 17 +++++++++++++---- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index e0774cbcb7..0ed75d0fbd 100644 --- a/src/peft/tuners/lora/layer.py +++ 
b/src/peft/tuners/lora/layer.py @@ -156,7 +156,6 @@ def update_layer( arrow_config: ArrowConfig = None, qalora_group_size: int = 32, inference_mode: bool = False, - is_tied: bool = False, tied_adapters: Optional[dict[str, nn.Parameter]] = None, **kwargs, ): @@ -197,10 +196,11 @@ def update_layer( # Actual trainable parameters self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False) self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias) - if is_tied: - if not tied_adapters: - raise RuntimeError("Layer is marked as tied, but tied adapters are not provided") + # Tying adapters is only implemented for Linear layers + # where the source is the embedding layer. + # Currently, this is the most prevelant way of tying layers (weight tying) + if tied_adapters: lora_A_params = tied_adapters["lora_A"] lora_B_params = tied_adapters["lora_B"] @@ -643,7 +643,6 @@ def __init__( use_alora=use_alora, lora_bias=lora_bias, arrow_config=arrow_config, - is_tied=kwargs.get("is_tied", False), tied_adapters=kwargs.get("tied_adapters"), ) self.is_target_conv_1d_layer = is_target_conv_1d_layer diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 41b5abf02d..598f202ca4 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -187,6 +187,8 @@ def _create_and_replace( r = lora_config.rank_pattern.get(r_key, lora_config.r) alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha) + # Checks if the target is marked as a tied layer + # If true, we add the reference to lora adapters of embedding layer in `tied_adapters` is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or []) tied_adapters = {} if is_tied: @@ -213,7 +215,6 @@ def _create_and_replace( "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), "parameter_name": parameter_name, - "is_tied": is_tied, "tied_adapters": tied_adapters, } @@ -824,8 +825,8 @@ def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st tied layers from `module_to_save`. Maintain a separate set for layers to be tied Args: - peft_config (LoraConfig): _description_ - tied_weight_keys (list[str]): _description_ + peft_config (LoraConfig) + tied_weight_keys (list[str]) """ tied_weight_keys = set(tied_weight_keys) setattr(peft_config, "modules_to_tie", tied_weight_keys) @@ -840,7 +841,15 @@ def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st setattr(peft_config, "modules_to_save", modules_to_save) - def _add_targets_to_tie(self, peft_config, tied_weight_keys): + def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]): + """ + Tied weight keys contains the layers tied to the embedding layer. Add embedding layer and remove rest of the + tied layers from `target_modules`. 
Maintain a separate set for layers to be tied + + Args: + peft_config (LoraConfig) + tied_weight_keys (list[str]) + """ tied_weight_keys = set(tied_weight_keys) setattr(peft_config, "target_modules_to_tie", tied_weight_keys) @@ -848,6 +857,7 @@ def _add_targets_to_tie(self, peft_config, tied_weight_keys): target_modules.add("embed_tokens") for m in tied_weight_keys: - target_modules.add(m) + if m in target_modules: + target_modules.remove(m) setattr(peft_config, "target_modules", target_modules) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index c3d207ff5b..de475c10a0 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -812,14 +812,18 @@ def inject_adapter( if state_dict is None: # normal mechanism: match the modules using the peft_config result = self._check_target_module_exists(peft_config, key) + # If the module is a tied layer, then we skip injecting + # any adapter here and tie it later to the adapter of the source layer. + # In this loop we only add adapters to the source layer (eg: embed_tokens) + # Only applicable if `ensure_weight_tying = True` for LoraConfig + if self._check_tied_module_exists(peft_config, key): + targets_to_tie.append(key) + continue if isinstance(result, _ExcludedModule): excluded_modules.append(key) elif not result: unmatched_modules.append(key) else: - if self._check_tied_module_exists(peft_config, key): - targets_to_tie.append(key) - continue self.targeted_module_names.append(key) parent, target, target_name = _get_submodules(model, key) self._check_target_module_compatiblity(peft_config, model, target_name) @@ -833,6 +837,10 @@ def inject_adapter( if key not in module_names: unmatched_modules.append(key) else: + # If the module is a tied layer, then we skip injecting + # any adapter here and tie it later to the adapter of the source layer. 
+ # In this loop we only add adapters to the source layer (eg: embed_tokens) + # Only applicable if `ensure_weight_tying = True` for LoraConfig if self._check_tied_module_exists(peft_config, key): targets_to_tie.append(key) continue @@ -855,7 +863,8 @@ def inject_adapter( peft_config=peft_config, model=model, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage ) - # Another loop for tying target modules + # Here we inject tied adapters for all the layers which were tied + # Only applicable if `ensure_weight_tying = True` for LoraConfig for key in targets_to_tie: self.targeted_module_names.append(key) parent, target, target_name = _get_submodules(model, key) From 628ce10fd835e61ed7242d2573c4079e6e991e88 Mon Sep 17 00:00:00 2001 From: r0 <11757603+romitjain@users.noreply.github.com> Date: Fri, 7 Nov 2025 10:31:18 +0530 Subject: [PATCH 08/21] Update src/peft/tuners/lora/model.py Co-authored-by: Benjamin Bossan --- src/peft/tuners/lora/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 598f202ca4..086b54fd9e 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -851,7 +851,7 @@ def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st tied_weight_keys (list[str]) """ tied_weight_keys = set(tied_weight_keys) - setattr(peft_config, "target_modules_to_tie", tied_weight_keys) + peft_config.target_modules_to_tie = tied_weight_keys target_modules = set(getattr(peft_config, "target_modules", []) or []) target_modules.add("embed_tokens") From 602ce1075f3d0236b91e24ad8a19f2fb25ff6c55 Mon Sep 17 00:00:00 2001 From: r0 <11757603+romitjain@users.noreply.github.com> Date: Fri, 7 Nov 2025 10:31:37 +0530 Subject: [PATCH 09/21] Update src/peft/tuners/lora/model.py Co-authored-by: Benjamin Bossan --- src/peft/tuners/lora/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 086b54fd9e..ef4da7b259 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -860,4 +860,4 @@ def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st if m in target_modules: target_modules.remove(m) - setattr(peft_config, "target_modules", target_modules) + peft_config.target_modules = target_modules From e2d0345f0c9b7e30231174360015ba3d0f7f91ea Mon Sep 17 00:00:00 2001 From: r0 <11757603+romitjain@users.noreply.github.com> Date: Fri, 7 Nov 2025 10:36:00 +0530 Subject: [PATCH 10/21] Apply suggestions from code review Co-authored-by: Benjamin Bossan --- src/peft/tuners/lora/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index ef4da7b259..0385a3a12b 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -829,7 +829,7 @@ def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st tied_weight_keys (list[str]) """ tied_weight_keys = set(tied_weight_keys) - setattr(peft_config, "modules_to_tie", tied_weight_keys) + peft_config.modules_to_tie = tied_weight_keys modules_to_save = getattr(peft_config, "modules_to_save", []) or [] if "embed_tokens" not in modules_to_save: @@ -839,7 +839,7 @@ def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st if m in modules_to_save: modules_to_save.remove(m) - setattr(peft_config, "modules_to_save", modules_to_save) + peft_config.modules_to_save = 
modules_to_save def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]): """ From 78800327726bf38670e18a746648767f19ae4f12 Mon Sep 17 00:00:00 2001 From: romit Date: Fri, 7 Nov 2025 05:53:12 +0000 Subject: [PATCH 11/21] Removed redundant change --- src/peft/tuners/lora/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 598f202ca4..47d1506aa7 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -262,7 +262,7 @@ def _create_and_replace( # adding an additional adapter: it is not automatically trainable new_module.requires_grad_(False) - self._replace_module(parent=parent, child_name=target_name, new_module=new_module, child=target) + self._replace_module(parent, target_name, new_module, target) def _replace_module(self, parent, child_name, new_module, child): # override in LoraModel to handle quantized weights properly From 46cca1e01be14b57bac8ed1707c5f537a51a5605 Mon Sep 17 00:00:00 2001 From: romit Date: Fri, 7 Nov 2025 13:28:24 +0000 Subject: [PATCH 12/21] Handling target_modules as str Signed-off-by: romit --- src/peft/tuners/lora/model.py | 13 ++++++++++++- src/peft/tuners/tuners_utils.py | 10 ++++++++-- tests/test_initialization.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 8244a3d8e9..baecbca191 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -853,7 +853,18 @@ def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st tied_weight_keys = set(tied_weight_keys) peft_config.target_modules_to_tie = tied_weight_keys - target_modules = set(getattr(peft_config, "target_modules", []) or []) + raw_target_modules = getattr(peft_config, "target_modules", None) + + if isinstance(raw_target_modules, str): + # The way weight tying is handled for adapters, we always want to add + # lora adapters to the input embedding layer (embed_tokens) + # instead of output embedding lauyer. 
+ if "lm_head" in raw_target_modules: + raw_target_modules = raw_target_modules.replace("lm_head", "embed_tokens") + peft_config.target_modules = raw_target_modules + return + + target_modules = set(raw_target_modules or []) target_modules.add("embed_tokens") for m in tied_weight_keys: diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index de475c10a0..61b3414e29 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -1274,8 +1274,14 @@ def _check_tied_modules(self, model: nn.Module, peft_config): modules_to_save = set(getattr(peft_config, "modules_to_save", []) or []) is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save) - target_modules = set(getattr(peft_config, "target_modules", []) or []) - is_embedding_in_target = any(m in EMBEDDING_LAYER_NAMES for m in target_modules) + raw_target_modules = getattr(peft_config, "target_modules", None) + if isinstance(raw_target_modules, str): + is_embedding_in_target = any( + match_target_against_key(raw_target_modules, m) for m in EMBEDDING_LAYER_NAMES + ) + else: + target_modules = set(raw_target_modules or []) + is_embedding_in_target = any(m in EMBEDDING_LAYER_NAMES for m in target_modules) tied_weight_keys = self._get_tied_weight_keys(model) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index ededc682e3..23e075aa8a 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -4970,6 +4970,37 @@ def test_weight_tying_tied_model_target_modules_lora(self, layer): assert embed_lora_A.data_ptr() == lm_lora_B.data_ptr() assert embed_lora_B.data_ptr() == lm_lora_A.data_ptr() + @pytest.mark.parametrize("layer", [".*embed_tokens$", ".*lm_head$", ".*(embed_tokens|lm_head)$"]) + def test_weight_tying_tied_model_target_modules_str_lora(self, layer): + # Same as `test_weight_tying_tied_model_target_modules_lora` but the tied module + # are passed as str + model = self.get_lm_model() + + embed_token_config = LoraConfig( + target_modules=layer, + ensure_weight_tying=True, + ) + + model = get_peft_model(model, embed_token_config) + + assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer) + assert isinstance(model.base_model.model.lm_head, LoraLayer) + + # Since embed_tokens and lm_head weights are transpose of each other + # lm_head lora_A == embed_tokens lora_B + adapter_name = "default" + + embed_lora_A = model.base_model.model.model.embed_tokens.lora_embedding_A[adapter_name] + embed_lora_B = model.base_model.model.model.embed_tokens.lora_embedding_B[adapter_name] + + lm_lora_A = model.base_model.model.lm_head.lora_A[adapter_name].weight + lm_lora_B = model.base_model.model.lm_head.lora_B[adapter_name].weight + + assert torch.allclose(embed_lora_A, lm_lora_B.T) + assert torch.allclose(embed_lora_B, lm_lora_A.T) + assert embed_lora_A.data_ptr() == lm_lora_B.data_ptr() + assert embed_lora_B.data_ptr() == lm_lora_A.data_ptr() + @pytest.mark.parametrize( "layer,tie_weights", [ From 2267a48095e2fe3eb001a734edffe152e26da447 Mon Sep 17 00:00:00 2001 From: r0 <11757603+romitjain@users.noreply.github.com> Date: Mon, 10 Nov 2025 21:21:13 +0530 Subject: [PATCH 13/21] Update src/peft/tuners/tuners_utils.py Co-authored-by: Benjamin Bossan --- src/peft/tuners/tuners_utils.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 61b3414e29..b9513563e1 100644 --- a/src/peft/tuners/tuners_utils.py +++ 
b/src/peft/tuners/tuners_utils.py @@ -378,14 +378,10 @@ def _check_tied_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Ma `bool` True if key matches any tied modules from config, False if no match found. """ - _target_modules_to_tie = getattr(peft_config, "target_modules_to_tie", {}) or {} - - if key in _target_modules_to_tie or any( - key.endswith(f".{target_key}") for target_key in _target_modules_to_tie - ): - return True - - return False + target_modules_to_tie = getattr(peft_config, "target_modules_to_tie", []) or [] + return key in target_modules_to_tie or any( + key.endswith(f".{target_key}") for target_key in target_modules_to_tie + ) @staticmethod def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Match[str] | None: From 5d5b8e434454572f55fae6cc5a59c4c6ed0bd53a Mon Sep 17 00:00:00 2001 From: romit Date: Wed, 12 Nov 2025 07:04:18 +0000 Subject: [PATCH 14/21] Updated regex matching Signed-off-by: romit --- src/peft/tuners/lora/model.py | 5 ++--- src/peft/tuners/tuners_utils.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index baecbca191..142a0bc259 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -859,9 +859,8 @@ def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st # The way weight tying is handled for adapters, we always want to add # lora adapters to the input embedding layer (embed_tokens) # instead of output embedding lauyer. - if "lm_head" in raw_target_modules: - raw_target_modules = raw_target_modules.replace("lm_head", "embed_tokens") - peft_config.target_modules = raw_target_modules + raw_target_modules = f"(?:{raw_target_modules}|.*embed_tokens$)" + peft_config.target_modules = raw_target_modules return target_modules = set(raw_target_modules or []) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index b9513563e1..b687b72211 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -1282,18 +1282,17 @@ def _check_tied_modules(self, model: nn.Module, peft_config): tied_weight_keys = self._get_tied_weight_keys(model) if getattr(peft_config, "ensure_weight_tying", False): - if (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys: + if tied_weight_keys: if is_embedding_to_save: self._add_modules_to_tie(peft_config, tied_weight_keys) elif is_embedding_in_target: self._add_targets_to_tie(peft_config, tied_weight_keys) - - elif not (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys: - warnings.warn( - "You have requested `ensure_weight_tying`, but no tied modules are added in either `modules_to_save` or `target_modules`" - ) - - elif not tied_weight_keys: + else: + warnings.warn( + "You have requested `ensure_weight_tying`, but no tied modules are added in either " + "`modules_to_save` or `target_modules`" + ) + else: warnings.warn("You have requested `ensure_weight_tying`, but no tied modules were found in the model") elif (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys: From c7cfe4064d25316a689c24a746c371e07929aaf0 Mon Sep 17 00:00:00 2001 From: r0 <11757603+romitjain@users.noreply.github.com> Date: Thu, 13 Nov 2025 10:11:08 +0530 Subject: [PATCH 15/21] Apply suggestion from @BenjaminBossan Co-authored-by: Benjamin Bossan --- src/peft/tuners/lora/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/model.py 
b/src/peft/tuners/lora/model.py index 142a0bc259..2ec8e64511 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -859,7 +859,7 @@ def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st # The way weight tying is handled for adapters, we always want to add # lora adapters to the input embedding layer (embed_tokens) # instead of output embedding lauyer. - raw_target_modules = f"(?:{raw_target_modules}|.*embed_tokens$)" + raw_target_modules = rf"(?:{raw_target_modules}|.*embed_tokens$)" peft_config.target_modules = raw_target_modules return From 8294ec73a9b918ae521dde7424903d23dce7630c Mon Sep 17 00:00:00 2001 From: romit Date: Thu, 13 Nov 2025 07:15:42 +0000 Subject: [PATCH 16/21] Added find layer by tensor Signed-off-by: romit --- src/peft/peft_model.py | 4 +++ src/peft/tuners/lora/model.py | 22 +++++++++------- src/peft/tuners/tuners_utils.py | 45 +++++++++++++++++++++------------ 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index aceaba0990..42bb24417d 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -316,6 +316,10 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion, peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs ) + # Before exporting the parameters we need to make sure + # all the tensors are contigious. Tensors can become non contigiuous + # if they are a transpose view of another tensor. This can happen + # during adapter tying or parameter sharing. for k, v in output_state_dict.items(): if not v.is_contiguous(): output_state_dict[k] = v.contiguous() diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 2ec8e64511..85d65c0465 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -25,11 +25,7 @@ from torch import nn from peft.import_utils import is_bnb_4bit_available, is_bnb_available -from peft.tuners.tuners_utils import ( - BaseTuner, - BaseTunerLayer, - replicate_layers, -) +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, find_parameter_name_by_tensor, replicate_layers from peft.utils import ( TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, AuxiliaryTrainingWrapper, @@ -832,8 +828,13 @@ def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st peft_config.modules_to_tie = tied_weight_keys modules_to_save = getattr(peft_config, "modules_to_save", []) or [] - if "embed_tokens" not in modules_to_save: - modules_to_save.append("embed_tokens") + + embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight) + # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name + embed_layer_name = embed_layer_name.replace(".weight", "") + + if embed_layer_name not in modules_to_save: + modules_to_save.append(embed_layer_name) for m in tied_weight_keys: if m in modules_to_save: @@ -854,17 +855,20 @@ def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st peft_config.target_modules_to_tie = tied_weight_keys raw_target_modules = getattr(peft_config, "target_modules", None) + embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight) + # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name + embed_layer_name = embed_layer_name.replace(".weight", "") if isinstance(raw_target_modules, str): # The way 
weight tying is handled for adapters, we always want to add # lora adapters to the input embedding layer (embed_tokens) # instead of output embedding lauyer. - raw_target_modules = rf"(?:{raw_target_modules}|.*embed_tokens$)" + raw_target_modules = rf"(?:{raw_target_modules}|.*{embed_layer_name}$)" peft_config.target_modules = raw_target_modules return target_modules = set(raw_target_modules or []) - target_modules.add("embed_tokens") + target_modules.add(embed_layer_name) for m in tied_weight_keys: if m in target_modules: diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index b687b72211..2e44b64fe6 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -57,6 +57,15 @@ from ._buffer_dict import BufferDict +warn_msg_weight_tying = ( + "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, " + "but no implementation exists to tie the adapters. " + "This can lead to complications, for example when merging the adapter " + "or converting your model to formats other than safetensors. " + "Check the discussion here: https://github.com/huggingface/peft/issues/2777" +) + + @contextmanager def onload_layer(layer): r""" @@ -1237,14 +1246,7 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys): Check `peft.tuners.lora.LoraModel._add_modules_to_tie` for an example. """ - msg = ( - "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, " - "but no implementation exists to tie the adapters. " - "This can lead to complications, for example when merging the adapter " - "or converting your model to formats other than safetensors. " - "Check the discussion here: https://github.com/huggingface/peft/issues/2777" - ) - warnings.warn(msg) + warnings.warn(warn_msg_weight_tying) def _add_targets_to_tie(self, peft_config, tied_weight_keys): """ @@ -1253,14 +1255,7 @@ def _add_targets_to_tie(self, peft_config, tied_weight_keys): Check `peft.tuners.lora.LoraModel._add_targets_to_tie` for an example. """ - msg = ( - "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, " - "but no implementation exists to tie the adapters. " - "This can lead to complications, for example when merging the adapter " - "or converting your model to formats other than safetensors. 
" - "Check the discussion here: https://github.com/huggingface/peft/issues/2777" - ) - warnings.warn(msg) + warnings.warn(warn_msg_weight_tying) def _check_tied_modules(self, model: nn.Module, peft_config): """ @@ -1975,6 +1970,24 @@ def replicate_layers(model: nn.Module, layer_map: list[tuple[int, int]]): model.config.num_hidden_layers = len(new_layers) +def find_parameter_name_by_tensor(model: nn.Module, reference_tensor: torch.Tensor) -> str: + """ + Find layer name from the model by matching the reference tensor to the model parameters + + Args: + model (nn.Module): The model with named modules + reference_tensor (torch.Tensor): The reference tensor to find + + Returns: + str: Name of the layer + """ + for n, m in model.named_parameters(): + if m is reference_tensor: + return n + + return "" + + ############################### # FUNCTIONS FOR functional.py # ############################### From 1da895f00a3fc4f34c6d111cbf419e980a05a9c9 Mon Sep 17 00:00:00 2001 From: romit Date: Fri, 14 Nov 2025 12:14:53 +0000 Subject: [PATCH 17/21] Fixed tests Signed-off-by: romit --- tests/test_tuners_utils.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py index 39c5e412f7..7fbafedbbf 100644 --- a/tests/test_tuners_utils.py +++ b/tests/test_tuners_utils.py @@ -1546,7 +1546,6 @@ def test_get_model_config_with_dataclass(self): class TestBaseTunerWarnForTiedEmbeddings: model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM" - warn_end_inject = "huggingface/peft/issues/2018." warn_end_merge = ( "# Now use the original model but in untied format\n" "model = AutoModelForCausalLM.from_pretrained(untied_model_dir)\n```\n" @@ -1564,28 +1563,16 @@ def _get_peft_model(self, tie_word_embeddings, target_module): def _is_warn_triggered(self, warning_list, endswith): return any(str(warning.message).endswith(endswith) for warning in warning_list) - def test_warn_for_tied_embeddings_inject(self, recwarn): - self._get_peft_model(tie_word_embeddings=True, target_module="lm_head") - assert self._is_warn_triggered(recwarn.list, self.warn_end_inject) - def test_warn_for_tied_embeddings_merge(self, recwarn): model = self._get_peft_model(tie_word_embeddings=True, target_module="lm_head") model.merge_and_unload() assert self._is_warn_triggered(recwarn.list, self.warn_end_merge) - def test_no_warn_for_untied_embeddings_inject(self, recwarn): - self._get_peft_model(tie_word_embeddings=False, target_module="lm_head") - assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject) - def test_no_warn_for_untied_embeddings_merge(self, recwarn): model_not_tied = self._get_peft_model(tie_word_embeddings=False, target_module="lm_head") model_not_tied.merge_and_unload() assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge) - def test_no_warn_for_no_target_module_inject(self, recwarn): - self._get_peft_model(tie_word_embeddings=True, target_module="q_proj") - assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject) - def test_no_warn_for_no_target_module_merge(self, recwarn): model_no_target_module = self._get_peft_model(tie_word_embeddings=True, target_module="q_proj") model_no_target_module.merge_and_unload() From d86ff7d028dcc2d19538697d2512134b98190f59 Mon Sep 17 00:00:00 2001 From: romit Date: Tue, 18 Nov 2025 16:58:47 +0000 Subject: [PATCH 18/21] Nit Signed-off-by: romit --- tests/test_initialization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_initialization.py 
b/tests/test_initialization.py index 23e075aa8a..5ed6b5709f 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -64,6 +64,7 @@ set_peft_model_state_dict, ) from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING +from peft.tuners.lokr.layer import LoKrLayer from peft.tuners.lora.config import CordaConfig from peft.tuners.lora.corda import preprocess_corda from peft.tuners.lora.layer import LoraLayer @@ -5053,8 +5054,6 @@ def test_alt_weight_tying_tied_model_target_modules_lora(self, layer, tie_weight raise NotImplementedError("Layer type {layer} is not supported for this test") def test_weight_tying_tied_model_target_modules_lokr(self): - from peft.tuners.lokr.layer import LoKrLayer - model = self.get_lm_model() embed_token_config = LoKrConfig(target_modules=["linear", "lm_head"]) From dc03dd4a64b5cc279fc4dce085533baecb1ac834 Mon Sep 17 00:00:00 2001 From: romit Date: Wed, 19 Nov 2025 14:19:28 +0000 Subject: [PATCH 19/21] Small fix to ensure correct layer name gets saved for target modules Signed-off-by: romit --- src/peft/tuners/lora/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index bfab3633f9..238746be31 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -867,7 +867,7 @@ def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight) # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name - embed_layer_name = embed_layer_name.replace(".weight", "") + embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "") if embed_layer_name not in modules_to_save: modules_to_save.append(embed_layer_name) @@ -893,7 +893,7 @@ def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st raw_target_modules = getattr(peft_config, "target_modules", None) embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight) # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name - embed_layer_name = embed_layer_name.replace(".weight", "") + embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "") if isinstance(raw_target_modules, str): # The way weight tying is handled for adapters, we always want to add From dbb00960aceb0733daf9eb039eaf12859a8bb436 Mon Sep 17 00:00:00 2001 From: r0 <11757603+romitjain@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:39:35 +0530 Subject: [PATCH 20/21] Apply suggestions from code review Co-authored-by: githubnemo --- src/peft/tuners/lora/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 6444d2ad75..198fb06c8a 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -675,7 +675,7 @@ class LoraConfig(PeftConfig): "Whether to tie weights or not after peft initialization. " "This will ensure that the adapters added to the tied layers " "are also tied. This is only applicable for layers passed via " - "`modules_to_save` and and `target_modules`." + "`modules_to_save` and `target_modules`." 
) }, ) From 67a71d63a18d3f28d2333eca23325087666d2326 Mon Sep 17 00:00:00 2001 From: romit Date: Mon, 15 Dec 2025 09:10:03 +0000 Subject: [PATCH 21/21] Updated matching logic Signed-off-by: romit --- src/peft/peft_model.py | 4 +- src/peft/tuners/lora/layer.py | 10 ++--- src/peft/tuners/lora/model.py | 70 ++++++++++++++++++++++----------- src/peft/tuners/tuners_utils.py | 8 ++-- 4 files changed, 58 insertions(+), 34 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 2b5f25f3f6..6f8f43bf83 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -329,8 +329,8 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion, peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs ) - # Before exporting the parameters we need to make sure - # all the tensors are contigious. Tensors can become non contigiuous + # Before exporting the parameters we need to make sure all the tensors are contigious as saving + # non-contiguous parameters is not supported. Tensors can become non contigiuous # if they are a transpose view of another tensor. This can happen # during adapter tying or parameter sharing. for k, v in output_state_dict.items(): diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 6f0ee90f70..d8952aab8a 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -156,7 +156,7 @@ def update_layer( arrow_config: ArrowConfig = None, qalora_group_size: int = 32, inference_mode: bool = False, - tied_adapters: Optional[dict[str, nn.Parameter]] = None, + tied_adapter: Optional[dict[str, nn.Parameter]] = None, use_bdlora=None, **kwargs, ): @@ -204,9 +204,9 @@ def update_layer( # Tying adapters is only implemented for Linear layers # where the source is the embedding layer. 
         # Currently, this is the most prevalent way of tying layers (weight tying)
-        if tied_adapters:
-            lora_A_params = tied_adapters["lora_A"]
-            lora_B_params = tied_adapters["lora_B"]
+        if tied_adapter:
+            lora_A_params = tied_adapter["lora_A"]
+            lora_B_params = tied_adapter["lora_B"]
             self.lora_A[adapter_name].weight = torch.nn.Parameter(lora_A_params)
             self.lora_B[adapter_name].weight = torch.nn.Parameter(lora_B_params)
@@ -648,7 +648,7 @@ def __init__(
             use_alora=use_alora,
             lora_bias=lora_bias,
             arrow_config=arrow_config,
-            tied_adapters=kwargs.get("tied_adapters"),
+            tied_adapter=kwargs.pop("tied_adapter", None),
             use_bdlora=use_bdlora,
             **kwargs,
         )
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index 798055472f..eac5cf21e9 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -15,6 +15,7 @@
 import math
 import operator
+import re
 import warnings
 from contextlib import contextmanager
 from dataclasses import replace
@@ -198,15 +199,15 @@ def _create_and_replace(
         alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)
         # Checks if the target is marked as a tied layer
-        # If true, we add the reference to lora adapters of embedding layer in `tied_adapters`
+        # If true, we add the reference to lora adapters of embedding layer in `tied_adapter`
         is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or [])
-        tied_adapters = {}
+        tied_adapter = {}
         if is_tied:
             tied_module = self.model.get_input_embeddings()
             emb_A = tied_module.lora_embedding_A[adapter_name]
             emb_B = tied_module.lora_embedding_B[adapter_name]
-            tied_adapters = {"lora_A": emb_B.t(), "lora_B": emb_A.t()}
+            tied_adapter = {"lora_A": emb_B.t(), "lora_B": emb_A.t()}
         kwargs = {
             "r": r,
@@ -227,7 +228,7 @@ def _create_and_replace(
             "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
             "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
             "parameter_name": parameter_name,
-            "tied_adapters": tied_adapters,
+            "tied_adapter": tied_adapter,
         }
         # for torchao merging, we need the get_apply_tensor_subclass from the quantization config
@@ -874,49 +875,65 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap
         return tensors_lora
-    def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
+    def _add_modules_to_save_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
         """
-        Tied weight keys contains the layers tied to the embedding layer. Add embedding layer and remove rest of the
-        tied layers from `module_to_save`. Maintain a separate set for layers to be tied
+        Add the embedding layer to `modules_to_save` and remove the rest of the tied layers from `modules_to_save`.
+        Maintain a separate set of layers to be tied in `peft_config.modules_to_tie`.
         Args:
-            peft_config (LoraConfig)
-            tied_weight_keys (list[str])
+            peft_config (LoraConfig) -- The configuration of the Lora model.
+            tied_weight_keys (list[str]) -- Contains the layers tied to the embedding layer.
""" tied_weight_keys = set(tied_weight_keys) peft_config.modules_to_tie = tied_weight_keys modules_to_save = getattr(peft_config, "modules_to_save", []) or [] - embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight) + embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings()) # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name - embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "") + if embed_layer_name.endswith(".weight"): + embed_layer_name = embed_layer_name.removesuffix(".weight") + prefix, sep, suffix = embed_layer_name.partition(".") + if sep and "model" in prefix: + embed_layer_name = suffix if embed_layer_name not in modules_to_save: modules_to_save.append(embed_layer_name) - for m in tied_weight_keys: - if m in modules_to_save: - modules_to_save.remove(m) + # Iterate over `tied_weight_keys` which are + # fully qualified keys and remove matching keys from + # `modules_to_save`. It will only remove first encounter + # in `module_to_save`, which should be safe, because `tied_weight_keys` + # is a unique set of keys + for key in tied_weight_keys: + for m in modules_to_save: + if re.match(rf"(^|.*\.){m}($|\..*)", key): + modules_to_save.remove(m) + break peft_config.modules_to_save = modules_to_save def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]): """ - Tied weight keys contains the layers tied to the embedding layer. Add embedding layer and remove rest of the - tied layers from `target_modules`. Maintain a separate set for layers to be tied + Add embedding layer to `target_modules` and remove rest of the tied layers from `target_modules`. Maintain a + separate set for layers to be tied in `peft_config.target_modules_to_tie` Args: - peft_config (LoraConfig) - tied_weight_keys (list[str]) + peft_config (LoraConfig) -- The configuration of the Lora model. + tied_weight_keys (list[str]) -- Contains the layers tied to the embedding layer. """ tied_weight_keys = set(tied_weight_keys) peft_config.target_modules_to_tie = tied_weight_keys raw_target_modules = getattr(peft_config, "target_modules", None) - embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight) + + embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings()) # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name - embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "") + if embed_layer_name.endswith(".weight"): + embed_layer_name = embed_layer_name.removesuffix(".weight") + prefix, sep, suffix = embed_layer_name.partition(".") + if sep and "model" in prefix: + embed_layer_name = suffix if isinstance(raw_target_modules, str): # The way weight tying is handled for adapters, we always want to add @@ -929,8 +946,15 @@ def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[st target_modules = set(raw_target_modules or []) target_modules.add(embed_layer_name) - for m in tied_weight_keys: - if m in target_modules: - target_modules.remove(m) + # Iterate over `tied_weight_keys` which are + # fully qualified keys and remove matching keys from + # `target_modules`. 
It will only remove the first match
+        # in `target_modules`, which should be safe because `tied_weight_keys`
+        # is a unique set of keys
+        for key in tied_weight_keys:
+            for m in target_modules:
+                if re.match(rf"(^|.*\.){m}($|\..*)", key):
+                    target_modules.remove(m)
+                    break
         peft_config.target_modules = target_modules
diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index c5a2f45d5c..cdeb99e063 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -1204,12 +1204,12 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
     def _get_module_names_tied_with_embedding(self) -> list[str]:
         return _get_module_names_tied_with_embedding(self)
-    def _add_modules_to_tie(self, peft_config, tied_weight_keys):
+    def _add_modules_to_save_to_tie(self, peft_config, tied_weight_keys):
         """
         This method adds modules to tie to `peft_config` so that those modules can be tied downstream. By default
         this method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this.
-        Check `peft.tuners.lora.LoraModel._add_modules_to_tie` for an example.
+        Check `peft.tuners.lora.LoraModel._add_modules_to_save_to_tie` for an example.
         """
         warnings.warn(warn_msg_weight_tying)
@@ -1244,7 +1244,7 @@ def _check_tied_modules(self, model: nn.Module, peft_config):
         if getattr(peft_config, "ensure_weight_tying", False):
             if tied_weight_keys:
                 if is_embedding_to_save:
-                    self._add_modules_to_tie(peft_config, tied_weight_keys)
+                    self._add_modules_to_save_to_tie(peft_config, tied_weight_keys)
                 elif is_embedding_in_target:
                     self._add_targets_to_tie(peft_config, tied_weight_keys)
                 else:
@@ -1946,7 +1946,7 @@ def find_parameter_name_by_tensor(model: nn.Module, reference_tensor: torch.Tens
     Returns:
         str: Name of the layer
     """
-    for n, m in model.named_parameters():
+    for n, m in model.named_modules():
         if m is reference_tensor:
             return n
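
A note on the end-to-end behaviour this series targets: with `ensure_weight_tying=True`, the adapter injected on a tied output embedding is supposed to share its weights with the adapter on the input embedding. The sketch below mirrors the assertions of the tests added above; it assumes the tiny "HuggingFaceH4/tiny-random-LlamaForCausalLM" checkpoint used in test_tuners_utils.py and Llama-style attribute paths, and it is not part of the patch itself.

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    # Tiny test checkpoint with tied input/output embeddings (assumption: same model as in the tests above).
    base = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/tiny-random-LlamaForCausalLM", tie_word_embeddings=True
    )
    config = LoraConfig(target_modules=["lm_head"], ensure_weight_tying=True)
    peft_model = get_peft_model(base, config)

    # The input embedding receives an adapter automatically and lm_head's adapter is tied to it.
    embed = peft_model.base_model.model.model.embed_tokens
    lm_head = peft_model.base_model.model.lm_head
    adapter = "default"

    # lm_head.weight is the transpose of embed_tokens.weight, so the tied adapters cross over:
    # lora_A of lm_head shares storage with lora_embedding_B of embed_tokens, and vice versa.
    assert embed.lora_embedding_B[adapter].data_ptr() == lm_head.lora_A[adapter].weight.data_ptr()
    assert embed.lora_embedding_A[adapter].data_ptr() == lm_head.lora_B[adapter].weight.data_ptr()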
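
The fully-qualified-key matching in `_add_modules_to_save_to_tie` and `_add_targets_to_tie` deserves a closer look: a bare module name from the config has to line up with a whole dot-separated component of a tied-weight key, not just any substring. A small standalone sketch of that behaviour (the `matches` helper is illustrative only, not part of the patch):

    import re

    def matches(short_name: str, qualified_key: str) -> bool:
        # Same pattern as used in the patch above.
        return re.match(rf"(^|.*\.){short_name}($|\..*)", qualified_key) is not None

    assert matches("lm_head", "lm_head")           # exact match
    assert matches("lm_head", "model.lm_head")     # prefixed by parent modules
    assert matches("lm_head", "lm_head.weight")    # followed by a parameter name
    assert not matches("lm_head", "my_lm_head")    # substring without a dot boundary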
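
Finally, the `.contiguous()` call in `save_mutated_as_lora` is needed because the tied adapters hold transposed views of the embedding adapters, and, as the added comment notes, non-contiguous parameters cannot be saved as-is. A quick standalone illustration with made-up shapes:

    import torch

    emb_lora_B = torch.randn(8, 16)   # stand-in for lora_embedding_B of embed_tokens
    tied_lora_A = emb_lora_B.t()      # what the tied lm_head adapter ends up holding

    assert tied_lora_A.data_ptr() == emb_lora_B.data_ptr()  # same storage, i.e. tied
    assert not tied_lora_A.is_contiguous()                  # a transposed view is not contiguous
    assert tied_lora_A.contiguous().is_contiguous()         # the copy made right before export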