diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index c4feb3e1be..104d81948a 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -329,6 +329,15 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion,
                     output_state_dict = save_mutated_as_lora(
                         peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs
                     )
+
+                # Before exporting the parameters we need to make sure all the tensors are contiguous, as saving
+                # non-contiguous parameters is not supported. Tensors can become non-contiguous
+                # if they are a transpose view of another tensor. This can happen
+                # during adapter tying or parameter sharing.
+                for k, v in output_state_dict.items():
+                    if not v.is_contiguous():
+                        output_state_dict[k] = v.contiguous()
+
                 safe_save_file(
                     output_state_dict,
                     os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index b4008f090a..f8e64be600 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -454,6 +454,11 @@ class LoraConfig(PeftConfig):
            `target_parameters`. As an example, for Llama4, you can pass:
            `target_parameters=['feed_forward.experts.gate_up_proj', 'feed_forward.experts.down_proj]`. Passing a
            string for regex matching is not implemented yet.
+        ensure_weight_tying (`bool`, *optional*):
+            Whether to tie weights or not after peft initialization. This will ensure that the adapters added to the
+            tied layers are also tied. This is only applicable for layers passed via `modules_to_save`,
+            `target_modules` and `trainable_token_indices`.
+
     """

     r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -759,8 +764,8 @@ class LoraConfig(PeftConfig):
             "help": (
                 "Whether to tie weights or not after peft initialization. "
                 "This will ensure that the adapters added to the tied layers "
-                "are also tied. This is applicable for layers passed via "
-                "`modules_to_save` and `trainable_token_indices`."
+                "are also tied. This is only applicable for layers passed via "
+                "`modules_to_save`, `target_modules` and `trainable_token_indices`."
             )
         },
     )
@@ -785,6 +790,7 @@ def __post_init__(self):

         if self.ensure_weight_tying:
             self.modules_to_tie = None
+            self.target_modules_to_tie = None

         if isinstance(self.target_parameters, str):
             raise TypeError("`target_parameters` must be a list of strings or None.")
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 13b5bd03cd..55e7a20b2a 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -159,6 +159,7 @@ def update_layer(
         arrow_config: ArrowConfig = None,
         qalora_group_size: int = 32,
         inference_mode: bool = False,
+        tied_adapter: Optional[dict[str, nn.Parameter]] = None,
         lora_ga_config=None,
         use_bdlora=None,
         **kwargs,
@@ -203,6 +204,17 @@ def update_layer(
         # Actual trainable parameters
         self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
         self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias)
+
+        # Tying adapters is only implemented for Linear layers
+        # where the source is the embedding layer.
+        # Currently, this is the most prevalent way of tying layers (weight tying).
+        if tied_adapter:
+            lora_A_params = tied_adapter["lora_A"]
+            lora_B_params = tied_adapter["lora_B"]
+
+            self.lora_A[adapter_name].weight = torch.nn.Parameter(lora_A_params)
+            self.lora_B[adapter_name].weight = torch.nn.Parameter(lora_B_params)
+
         self.lora_bias[adapter_name] = lora_bias

         if use_rslora:
@@ -774,6 +786,7 @@ def __init__(
             use_alora=use_alora,
             lora_bias=lora_bias,
             arrow_config=arrow_config,
+            tied_adapter=kwargs.pop("tied_adapter", None),
             lora_ga_config=lora_ga_config,
             use_bdlora=use_bdlora,
             **kwargs,
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index e3c032d44d..f10cff0641 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -15,6 +15,7 @@

 import math
 import operator
+import re
 import warnings
 from contextlib import contextmanager
 from dataclasses import replace
@@ -27,11 +28,7 @@
 from torch import nn

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
-from peft.tuners.tuners_utils import (
-    BaseTuner,
-    BaseTunerLayer,
-    replicate_layers,
-)
+from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, find_parameter_name_by_tensor, replicate_layers
 from peft.utils import (
     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
     AuxiliaryTrainingWrapper,
@@ -202,6 +199,17 @@ def _create_and_replace(
         r = lora_config.rank_pattern.get(r_key, lora_config.r)
         alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)

+        # Checks if the target is marked as a tied layer.
+        # If true, we add the reference to the lora adapters of the embedding layer in `tied_adapter`.
+        is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or [])
+        tied_adapter = {}
+        if is_tied:
+            tied_module = self.model.get_input_embeddings()
+            emb_A = tied_module.lora_embedding_A[adapter_name]
+            emb_B = tied_module.lora_embedding_B[adapter_name]
+
+            tied_adapter = {"lora_A": emb_B.t(), "lora_B": emb_A.t()}
+
         kwargs = {
             "r": r,
             "lora_alpha": alpha,
@@ -222,6 +230,7 @@ def _create_and_replace(
             "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
             "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
             "parameter_name": parameter_name,
+            "tied_adapter": tied_adapter,
         }

         # for torchao merging, we need the get_apply_tensor_subclass from the quantization config
@@ -270,6 +279,7 @@ def _create_and_replace(
         if adapter_name not in self.active_adapters:
             # adding an additional adapter: it is not automatically trainable
             new_module.requires_grad_(False)
+
         self._replace_module(parent, target_name, new_module, target)

     def _replace_module(self, parent, child_name, new_module, child):
@@ -878,8 +888,86 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap

         return tensors_lora

-    def _add_modules_to_tie(self, peft_config, tied_weight_keys):
-        modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
-        missing_keys = set(tied_weight_keys) - modules_to_save
+    def _add_modules_to_save_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
+        """
+        Add the embedding layer to `modules_to_save` and remove the rest of the tied layers from `modules_to_save`.
+        Maintain a separate set for the layers to be tied in `peft_config.modules_to_tie`.
+
+        Args:
+            peft_config (LoraConfig) -- The configuration of the Lora model.
+            tied_weight_keys (list[str]) -- Contains the layers tied to the embedding layer.
+        """
+        tied_weight_keys = set(tied_weight_keys)
+        peft_config.modules_to_tie = tied_weight_keys
+
+        modules_to_save = getattr(peft_config, "modules_to_save", []) or []
+
+        embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings())
+        # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name
+        if embed_layer_name.endswith(".weight"):
+            embed_layer_name = embed_layer_name.removesuffix(".weight")
+        prefix, sep, suffix = embed_layer_name.partition(".")
+        if sep and "model" in prefix:
+            embed_layer_name = suffix
+
+        if embed_layer_name not in modules_to_save:
+            modules_to_save.append(embed_layer_name)
+
+        # Iterate over `tied_weight_keys`, which are
+        # fully qualified keys, and remove matching keys from
+        # `modules_to_save`. Only the first match in
+        # `modules_to_save` is removed, which should be safe, because `tied_weight_keys`
+        # is a unique set of keys.
+        for key in tied_weight_keys:
+            for m in modules_to_save:
+                if re.match(rf"(^|.*\.){m}($|\..*)", key):
+                    modules_to_save.remove(m)
+                    break
+
+        peft_config.modules_to_save = modules_to_save
+
+    def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
+        """
+        Add the embedding layer to `target_modules` and remove the rest of the tied layers from `target_modules`.
+        Maintain a separate set for the layers to be tied in `peft_config.target_modules_to_tie`.
+
+        Args:
+            peft_config (LoraConfig) -- The configuration of the Lora model.
+            tied_weight_keys (list[str]) -- Contains the layers tied to the embedding layer.
+        """
+        tied_weight_keys = set(tied_weight_keys)
+        peft_config.target_modules_to_tie = tied_weight_keys
+
+        raw_target_modules = getattr(peft_config, "target_modules", None)
+
+        embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings())
+        # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name
+        if embed_layer_name.endswith(".weight"):
+            embed_layer_name = embed_layer_name.removesuffix(".weight")
+        prefix, sep, suffix = embed_layer_name.partition(".")
+        if sep and "model" in prefix:
+            embed_layer_name = suffix
+
+        if isinstance(raw_target_modules, str):
+            # The way weight tying is handled for adapters, we always want to add
+            # lora adapters to the input embedding layer (embed_tokens)
+            # instead of the output embedding layer.
+            raw_target_modules = rf"(?:{raw_target_modules}|.*{embed_layer_name}$)"
+            peft_config.target_modules = raw_target_modules
+            return

-        peft_config.modules_to_tie = missing_keys
+        target_modules = set(raw_target_modules or [])
+        target_modules.add(embed_layer_name)
+
+        # Iterate over `tied_weight_keys`, which are
+        # fully qualified keys, and remove matching keys from
+        # `target_modules`. Only the first match in
+        # `target_modules` is removed, which should be safe, because `tied_weight_keys`
+        # is a unique set of keys.
+        for key in tied_weight_keys:
+            for m in target_modules:
+                if re.match(rf"(^|.*\.){m}($|\..*)", key):
+                    target_modules.remove(m)
+                    break
+
+        peft_config.target_modules = target_modules
diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index a94fe4f2e2..1acb617747 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -58,6 +58,13 @@
 from ._buffer_dict import BufferDict


+warn_msg_weight_tying = (
+    "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
+    "but no implementation exists to tie the adapters. "
+    "This can lead to complications, for example when merging the adapter "
+    "or converting your model to formats other than safetensors. "
+    "Check the discussion here: https://github.com/huggingface/peft/issues/2777"
+)

 _torch_supports_dtensor = version.parse(torch.__version__) >= version.parse("2.5.0")
 _torch_supports_distributed = _torch_supports_dtensor and torch.distributed.is_available()
@@ -367,6 +374,26 @@ def _prepare_model(self, peft_config: PeftConfig, model: nn.Module):
         """
         pass

+    @staticmethod
+    def _check_tied_module_exists(peft_config: PeftConfig, key: str) -> bool:
+        """
+        A helper method to check if the passed module's key name matches any of the tied modules.
+
+        Args:
+            peft_config (`PeftConfig`):
+                A config to match target modules from.
+            key (`str`):
+                A key to search any matches in config.
+
+        Returns:
+            `bool`
+                True if key matches any tied modules from config, False if no match found.
+        """
+        target_modules_to_tie = getattr(peft_config, "target_modules_to_tie", []) or []
+        return key in target_modules_to_tie or any(
+            key.endswith(f".{target_key}") for target_key in target_modules_to_tie
+        )
+
     @staticmethod
     def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Match[str] | None:
         """
@@ -703,6 +730,7 @@ def inject_adapter(
         excluded_modules = []
         unmatched_modules = []
         targeted_modules_from_peft_config: list[str] = []  # only relevant if state_dict is passed
+        targets_to_tie: list[str] = []
         # Note: If possible, all checks should be performed *at the start of this method*.
         # This way, we can raise early if something goes wrong, without leaving the model
         # in a bad (half-initialized) state.
@@ -793,6 +821,13 @@ def inject_adapter(
             if state_dict is None:
                 # normal mechanism: match the modules using the peft_config
                 result = self._check_target_module_exists(peft_config, key)
+                # If the module is a tied layer, then we skip injecting
+                # any adapter here and tie it later to the adapter of the source layer.
+                # In this loop we only add adapters to the source layer (e.g. embed_tokens).
+                # Only applicable if `ensure_weight_tying = True` for LoraConfig.
+                if self._check_tied_module_exists(peft_config, key):
+                    targets_to_tie.append(key)
+                    continue
                 if isinstance(result, _ExcludedModule):
                     excluded_modules.append(key)
                 elif not result:
@@ -813,6 +848,13 @@ def inject_adapter(
                 if key not in module_names:
                     unmatched_modules.append(key)
                 else:
+                    # If the module is a tied layer, then we skip injecting
+                    # any adapter here and tie it later to the adapter of the source layer.
+                    # In this loop we only add adapters to the source layer (e.g. embed_tokens).
+                    # Only applicable if `ensure_weight_tying = True` for LoraConfig.
+                    if self._check_tied_module_exists(peft_config, key):
+                        targets_to_tie.append(key)
+                        continue
                     self.targeted_module_names.append(key)
                     parent, target, target_name = _get_submodules(model, key)
                     self._check_target_module_compatiblity(peft_config, model, target_name)
@@ -832,6 +874,16 @@ def inject_adapter(
             peft_config=peft_config, model=model, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage
         )

+        # Here we inject tied adapters for all the layers which were marked as tied.
+        # Only applicable if `ensure_weight_tying = True` for LoraConfig.
+        for key in targets_to_tie:
+            self.targeted_module_names.append(key)
+            parent, target, target_name = _get_submodules(model, key)
+            self._check_target_module_compatiblity(peft_config, model, target_name)
+            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+            with ctx():
+                self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
+
         ####################
         # CHECK FOR ERRORS #
         ####################
@@ -918,15 +970,6 @@ def inject_adapter(
                     RuntimeWarning,
                 )

-        tied_target_modules = self._get_tied_target_modules(model=model)
-        if tied_target_modules:
-            warnings.warn(
-                f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
-                "This can lead to complications, for example when merging the adapter "
-                "or converting your model to formats other than safetensors. "
-                "See for example https://github.com/huggingface/peft/issues/2018."
-            )
-
         ################
         # HOUSEKEEPING #
         ################
@@ -1166,43 +1209,58 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]:

     def _get_module_names_tied_with_embedding(self) -> list[str]:
         return _get_module_names_tied_with_embedding(self)

-    def _add_modules_to_tie(self, peft_config, tied_weight_keys):
+    def _add_modules_to_save_to_tie(self, peft_config, tied_weight_keys):
         """
         This method adds modules to tie to `peft_config` so that those modules can be tied downstream. By default this
         method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this.
+
+        Check `peft.tuners.lora.LoraModel._add_modules_to_save_to_tie` for an example.
         """
-        msg = (
-            "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
-            "but no implementation exists to tie the adapters. "
-            "This can lead to complications, for example when merging the adapter "
-            "or converting your model to formats other than safetensors. "
-            "Check the discussion here: https://github.com/huggingface/peft/issues/2777"
-        )
-        warnings.warn(msg)
+        warnings.warn(warn_msg_weight_tying)
+
+    def _add_targets_to_tie(self, peft_config, tied_weight_keys):
+        """
+        This method adds targets to tie to `peft_config` so that those modules can be tied downstream. By default this
+        method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this.
+
+        Check `peft.tuners.lora.LoraModel._add_targets_to_tie` for an example.
+        """
+        warnings.warn(warn_msg_weight_tying)

     def _check_tied_modules(self, model: nn.Module, peft_config):
         """
-        Checks if any of the tied layers are targetted via `modules_to_save`. Updates the `peft_config.modules_to_tie`
-        with any layers that needs to be tied
+        Checks if any of the tied layers are targeted via `modules_to_save` or `target_modules`. Updates the
+        `peft_config` in place with any layers/adapters that need to be tied
         """
         modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
         is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save)

+        raw_target_modules = getattr(peft_config, "target_modules", None)
+        if isinstance(raw_target_modules, str):
+            is_embedding_in_target = any(
+                match_target_against_key(raw_target_modules, m) for m in EMBEDDING_LAYER_NAMES
+            )
+        else:
+            target_modules = set(raw_target_modules or [])
+            is_embedding_in_target = any(m in EMBEDDING_LAYER_NAMES for m in target_modules)
+
         tied_weight_keys = self._get_module_names_tied_with_embedding()

         if getattr(peft_config, "ensure_weight_tying", False):
-            if is_embedding_to_save and tied_weight_keys:
-                self._add_modules_to_tie(peft_config, tied_weight_keys)
-
-            elif not is_embedding_to_save and tied_weight_keys:
-                warnings.warn(
-                    "You have requested `ensure_weight_tying`, but no tied modules are added in `modules_to_save`"
-                )
-
-            elif not tied_weight_keys:
+            if tied_weight_keys:
+                if is_embedding_to_save:
+                    self._add_modules_to_save_to_tie(peft_config, tied_weight_keys)
+                elif is_embedding_in_target:
+                    self._add_targets_to_tie(peft_config, tied_weight_keys)
+                else:
+                    warnings.warn(
+                        "You have requested `ensure_weight_tying`, but no tied modules are added in either "
+                        "`modules_to_save` or `target_modules`"
+                    )
+            else:
                 warnings.warn("You have requested `ensure_weight_tying`, but no tied modules were found in the model")

-        elif is_embedding_to_save and tied_weight_keys:
+        elif (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys:
             if hasattr(peft_config, "ensure_weight_tying"):
                 msg = (
                     "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
@@ -1919,6 +1977,24 @@ def replicate_layers(model: nn.Module, layer_map: list[tuple[int, int]]):
     model.config.num_hidden_layers = len(new_layers)


+def find_parameter_name_by_tensor(model: nn.Module, reference_tensor: torch.Tensor) -> str:
+    """
+    Find the layer name from the model by matching the reference tensor against the model's modules.
+
+    Args:
+        model (nn.Module): The model with named modules.
+        reference_tensor (torch.Tensor): The reference tensor to find.
+
+    Returns:
+        str: Name of the layer.
+    """
+    for n, m in model.named_modules():
+        if m is reference_tensor:
+            return n
+
+    return ""
+
+
 ###############################
 # FUNCTIONS FOR functional.py #
 ###############################
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index c72018854c..863018c911 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -65,6 +65,7 @@
     set_peft_model_state_dict,
 )
 from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING
+from peft.tuners.lokr.layer import LoKrLayer
 from peft.tuners.lora.config import CordaConfig
 from peft.tuners.lora.corda import preprocess_corda
 from peft.tuners.lora.layer import LoraLayer
@@ -4968,7 +4969,8 @@ def get_input_embeddings(self):

         return CausalLM().eval().to(self.torch_device)

-    def test_weight_tying_tied_model_lora(self):
+    @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]])
+    def test_weight_tying_tied_model_lora(self, layer):
         # If weight tying is enabled and `embed_tokens`
         # is passed as a `modules_to_save`, it needs to be ensured
         # that lm_head is tied to the adapter added to `embed_tokens`
@@ -4976,19 +4978,15 @@
         model = self.get_lm_model()

         embed_token_config = LoraConfig(
-            modules_to_save=["embed_tokens"],
+            modules_to_save=layer if isinstance(layer, list) else [layer],
             target_modules=["linear"],
             ensure_weight_tying=True,
         )

         model = get_peft_model(model, embed_token_config)

-        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
-            "Embed tokens is not added in Modules to Save"
-        )
-        assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), (
-            "Embed tokens and LM head types are not same"
-        )
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper)
+        assert isinstance(model.base_model.model.lm_head, ModulesToSaveWrapper)

         # Validating that all model parameters are same
         embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
@@ -4996,41 +4994,46 @@
         for k in embed_np.keys():
             assert torch.allclose(embed_np[k], lm_head_np[k])
-            assert embed_np[k] is lm_head_np[k]
+            assert embed_np[k].data_ptr() == lm_head_np[k].data_ptr()

-    def test_weight_tying_non_tied_model_lora(self):
-        model = self.get_lm_model(tie_weights=False)
+    @pytest.mark.parametrize(
+        "layer,tie_weights",
+        [
+            ("lm_head", True),
+            ("lm_head", False),
+            ("embed_tokens", True),
+            ("embed_tokens", False),
+            (["embed_tokens", "lm_head"], True),
+            (["embed_tokens", "lm_head"], False),
+        ],
+    )
+    def test_alt_weight_tying_tied_model_lora(self, layer, tie_weights):
+        model = self.get_lm_model(tie_weights=tie_weights)

         embed_token_config = LoraConfig(
-            modules_to_save=["embed_tokens"],
+            modules_to_save=layer if isinstance(layer, list) else [layer],
             target_modules=["linear"],
-            ensure_weight_tying=True,
+            ensure_weight_tying=not tie_weights,
         )

-        with pytest.warns(UserWarning, match="no tied modules were found in the model"):
-            model = get_peft_model(model, embed_token_config)
-
-        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
-            "Embed tokens is not added in Modules to Save"
-        )
-        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
-            "LM head is not of type nn.linear"
-        )
+        if tie_weights:
+            wrn_msg = "`ensure_weight_tying` is not set to True"
+        else:
+            wrn_msg = "no tied modules were found in the model"

-    def test_not_weight_tying_tied_model_lora(self):
-        model = self.get_lm_model()
-        embed_token_config = LoraConfig(
-            modules_to_save=["embed_tokens"],
-            target_modules=["linear"],
-            ensure_weight_tying=False,
-        )
-        with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"):
+        with pytest.warns(UserWarning, match=wrn_msg):
             model = get_peft_model(model, embed_token_config)

-        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
-            "Embed tokens is not added in Modules to Save"
-        )
-        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
-            "LM head is not of type nn.linear"
-        )
+        if layer == "embed_tokens":
+            assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper)
+            assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear)
+        elif layer == "lm_head":
+            assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding)
+            assert isinstance(model.base_model.model.lm_head, ModulesToSaveWrapper)
+        elif layer == ["embed_tokens", "lm_head"]:
+            assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper)
+            assert isinstance(model.base_model.model.lm_head, ModulesToSaveWrapper)
+        else:
+            raise NotImplementedError(f"Layer type {layer} is not supported for this test")

     def test_weight_tying_tied_model_no_embed_lora(self):
         model = self.get_lm_model()
@@ -5039,7 +5042,7 @@ def test_weight_tying_tied_model_no_embed_lora(self):
             ensure_weight_tying=True,
         )

-        with pytest.warns(UserWarning, match="no tied modules are added in `modules_to_save`"):
+        with pytest.warns(UserWarning, match="no tied modules are added"):
             model = get_peft_model(model, embed_token_config)

         assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding)
@@ -5064,6 +5067,128 @@ def test_weight_tying_tied_model_lokr(self):
         assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
             "Embed tokens is not added in Modules to Save"
         )
-        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
-            "LM head is not of type nn.linear"
-        )
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear)
+
+    @pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]])
+    def test_weight_tying_tied_model_target_modules_lora(self, layer):
+        # Same as `test_weight_tying_tied_model_lora` but the tied module is passed
+        # in `target_modules` instead of `modules_to_save`.
+        model = self.get_lm_model()
+
+        embed_token_config = LoraConfig(
+            target_modules=["linear"] + (layer if isinstance(layer, list) else [layer]),
+            ensure_weight_tying=True,
+        )
+
+        model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer)
+        assert isinstance(model.base_model.model.lm_head, LoraLayer)
+
+        # Since embed_tokens and lm_head weights are transposes of each other,
+        # lm_head lora_A == embed_tokens lora_B
+        adapter_name = "default"
+
+        embed_lora_A = model.base_model.model.model.embed_tokens.lora_embedding_A[adapter_name]
+        embed_lora_B = model.base_model.model.model.embed_tokens.lora_embedding_B[adapter_name]
+
+        lm_lora_A = model.base_model.model.lm_head.lora_A[adapter_name].weight
+        lm_lora_B = model.base_model.model.lm_head.lora_B[adapter_name].weight
+
+        assert torch.allclose(embed_lora_A, lm_lora_B.T)
+        assert torch.allclose(embed_lora_B, lm_lora_A.T)
+        assert embed_lora_A.data_ptr() == lm_lora_B.data_ptr()
+        assert embed_lora_B.data_ptr() == lm_lora_A.data_ptr()
+
+    @pytest.mark.parametrize("layer", [".*embed_tokens$", ".*lm_head$", ".*(embed_tokens|lm_head)$"])
+    def test_weight_tying_tied_model_target_modules_str_lora(self, layer):
+        # Same as `test_weight_tying_tied_model_target_modules_lora` but the tied modules
+        # are passed as a regex str
+        model = self.get_lm_model()
+
+        embed_token_config = LoraConfig(
+            target_modules=layer,
+            ensure_weight_tying=True,
+        )
+
+        model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer)
+        assert isinstance(model.base_model.model.lm_head, LoraLayer)
+
+        # Since embed_tokens and lm_head weights are transposes of each other,
+        # lm_head lora_A == embed_tokens lora_B
+        adapter_name = "default"
+
+        embed_lora_A = model.base_model.model.model.embed_tokens.lora_embedding_A[adapter_name]
+        embed_lora_B = model.base_model.model.model.embed_tokens.lora_embedding_B[adapter_name]
+
+        lm_lora_A = model.base_model.model.lm_head.lora_A[adapter_name].weight
+        lm_lora_B = model.base_model.model.lm_head.lora_B[adapter_name].weight
+
+        assert torch.allclose(embed_lora_A, lm_lora_B.T)
+        assert torch.allclose(embed_lora_B, lm_lora_A.T)
+        assert embed_lora_A.data_ptr() == lm_lora_B.data_ptr()
+        assert embed_lora_B.data_ptr() == lm_lora_A.data_ptr()
+
+    @pytest.mark.parametrize(
+        "layer,tie_weights",
+        [
+            ("lm_head", True),
+            ("lm_head", False),
+            ("embed_tokens", True),
+            ("embed_tokens", False),
+            (["embed_tokens", "lm_head"], True),
+            (["embed_tokens", "lm_head"], False),
+        ],
+    )
+    def test_alt_weight_tying_tied_model_target_modules_lora(self, layer, tie_weights):
+        # Ensure a warning is raised when the model's weight tying and `ensure_weight_tying`
+        # do not match, even if the tied module name is present in `target_modules`.
+        model = self.get_lm_model(tie_weights=tie_weights)
+        embed_token_config = LoraConfig(
+            target_modules=["linear"] + (layer if isinstance(layer, list) else [layer]),
+            ensure_weight_tying=not tie_weights,
+        )
+
+        if tie_weights:
+            wrn_msg = "`ensure_weight_tying` is not set to True"
+        else:
+            wrn_msg = "no tied modules were found in the model"
+
+        with pytest.warns(UserWarning, match=wrn_msg):
+            model = get_peft_model(model, embed_token_config)
+
+        if layer == "embed_tokens":
+            assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer)
+            assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear)
+        elif layer == "lm_head":
+            assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding)
+            assert isinstance(model.base_model.model.lm_head, LoraLayer)
+        elif layer == ["embed_tokens", "lm_head"]:
+            assert isinstance(model.base_model.model.model.embed_tokens, LoraLayer)
+            assert isinstance(model.base_model.model.lm_head, LoraLayer)
+
+            adapter_name = "default"
+
+            embed_lora_A = model.base_model.model.model.embed_tokens.lora_embedding_A[adapter_name]
+            embed_lora_B = model.base_model.model.model.embed_tokens.lora_embedding_B[adapter_name]
+
+            lm_lora_A = model.base_model.model.lm_head.lora_A[adapter_name].weight
+            lm_lora_B = model.base_model.model.lm_head.lora_B[adapter_name].weight
+
+            assert embed_lora_A.data_ptr() != lm_lora_B.data_ptr()
+            assert embed_lora_B.data_ptr() != lm_lora_A.data_ptr()
+        else:
+            raise NotImplementedError(f"Layer type {layer} is not supported for this test")
+
+    def test_weight_tying_tied_model_target_modules_lokr(self):
+        model = self.get_lm_model()
+
+        embed_token_config = LoKrConfig(target_modules=["linear", "lm_head"])
+
+        with pytest.warns(UserWarning, match="no implementation exists to tie the adapters"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding)
+        assert isinstance(model.base_model.model.lm_head, LoKrLayer)
diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py
index fc51a16676..a464f60aff 100644
--- a/tests/test_trainable_tokens.py
+++ b/tests/test_trainable_tokens.py
@@ -715,15 +715,6 @@ def test_weight_tying_applied_when_model_is_tied_standalone(self, model_weight_t

         assert merged_model.model.decoder.embed_tokens.weight.data_ptr() == merged_model.lm_head.weight.data_ptr()

-    def test_weight_tying_normally_issues_warning(self, model_weight_tied, recwarn):
-        # When using models with weight tying and targeting the embedding or the tied layer should raise a warning.
-        peft_config = LoraConfig(target_modules=["embed_tokens"])
-        peft_model = get_peft_model(model_weight_tied, peft_config)
-
-        warnings = [w.message.args[0] for w in recwarn]
-        warnings = [msg for msg in warnings if "Model with `tie_word_embeddings=True` and the" in msg]
-        assert warnings
-
     def test_weight_tying_state_dict_ignores_tied_weights(self, model_weight_tied):
         # since weight tying is currently not supported make sure that an error is raised when attempting
         # to use a model that has tied input/output embeddings
diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py
index e93f87610b..05edcf0e3f 100644
--- a/tests/test_tuners_utils.py
+++ b/tests/test_tuners_utils.py
@@ -1547,7 +1547,6 @@ def test_get_model_config_with_dataclass(self):

 class TestBaseTunerWarnForTiedEmbeddings:
     model_id = "peft-internal-testing/tiny-random-LlamaForCausalLM"
-    warn_end_inject = "huggingface/peft/issues/2018."
     warn_end_merge = (
         "# Now use the original model but in untied format\n"
         "model = AutoModelForCausalLM.from_pretrained(untied_model_dir)\n```\n"
@@ -1565,28 +1564,16 @@ def _get_peft_model(self, tie_word_embeddings, target_module):

     def _is_warn_triggered(self, warning_list, endswith):
         return any(str(warning.message).endswith(endswith) for warning in warning_list)

-    def test_warn_for_tied_embeddings_inject(self, recwarn):
-        self._get_peft_model(tie_word_embeddings=True, target_module="lm_head")
-        assert self._is_warn_triggered(recwarn.list, self.warn_end_inject)
-
     def test_warn_for_tied_embeddings_merge(self, recwarn):
         model = self._get_peft_model(tie_word_embeddings=True, target_module="lm_head")
         model.merge_and_unload()
         assert self._is_warn_triggered(recwarn.list, self.warn_end_merge)

-    def test_no_warn_for_untied_embeddings_inject(self, recwarn):
-        self._get_peft_model(tie_word_embeddings=False, target_module="lm_head")
-        assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject)
-
     def test_no_warn_for_untied_embeddings_merge(self, recwarn):
         model_not_tied = self._get_peft_model(tie_word_embeddings=False, target_module="lm_head")
         model_not_tied.merge_and_unload()
         assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge)

-    def test_no_warn_for_no_target_module_inject(self, recwarn):
-        self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
-        assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject)
-
     def test_no_warn_for_no_target_module_merge(self, recwarn):
         model_no_target_module = self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
         model_no_target_module.merge_and_unload()
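
Usage sketch (not part of the patch): the snippet below exercises the new `ensure_weight_tying` behavior with tied modules passed via `target_modules`, mirroring the assertions in the tests above. It assumes a checkpoint with `tie_word_embeddings=True`; the tiny Llama model id is taken from `test_tuners_utils.py` and is used purely for illustration, and the attribute paths follow the test assertions.

```python
# Hedged example: verify that the lm_head LoRA adapter is tied to the
# embed_tokens adapter when `ensure_weight_tying=True` (a sketch, assuming a
# model whose input and output embeddings are tied).
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# Tiny model used in the PEFT test suite; tying is forced here for the example.
model = AutoModelForCausalLM.from_pretrained(
    "peft-internal-testing/tiny-random-LlamaForCausalLM", tie_word_embeddings=True
)

config = LoraConfig(
    target_modules=["q_proj", "embed_tokens", "lm_head"],
    ensure_weight_tying=True,  # tie the lm_head adapter to the embed_tokens adapter
)
peft_model = get_peft_model(model, config)

# embed_tokens gets lora_embedding_A/B; lm_head's lora_A/B reuse the same storage
# (transposed), so the tied pairs share data pointers, as asserted in the tests.
emb = peft_model.base_model.model.model.embed_tokens
head = peft_model.base_model.model.lm_head
assert emb.lora_embedding_A["default"].data_ptr() == head.lora_B["default"].weight.data_ptr()
assert emb.lora_embedding_B["default"].data_ptr() == head.lora_A["default"].weight.data_ptr()
```

Because the tied lm_head adapter is a transposed view of the embedding adapter, the tensors handed to `save_pretrained` can be non-contiguous, which is why the `peft_model.py` hunk above calls `.contiguous()` before `safe_save_file`.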