-
Notifications
You must be signed in to change notification settings - Fork 2.2k
ENH: Tie weights for target_modules in Lora (#2864) #2879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
4c6d15f
4b91220
46b803e
37b1e06
8388aa8
cd6c6d0
0cb44e8
628ce10
602ce10
e2d0345
7880032
f73af50
46cca1e
2267a48
5d5b8e4
c7cfe40
8294ec7
7370a21
1da895f
d86ff7d
dc03dd4
c79a64c
0715451
dbb0096
06d4b7f
67a71d6
8889558
9f7702f
4d5d681
ba4d81f
e399072
d8a8edf
167bdce
3f16c36
19929f7
d5082ff
d3e954f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -187,6 +187,17 @@ def _create_and_replace( | |
| r = lora_config.rank_pattern.get(r_key, lora_config.r) | ||
| alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha) | ||
|
|
||
| # Checks if the target is marked as a tied layer | ||
| # If true, we add the reference to lora adapters of embedding layer in `tied_adapters` | ||
| is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or []) | ||
romitjain marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| tied_adapters = {} | ||
| if is_tied: | ||
| tied_module = self.model.get_input_embeddings() | ||
| emb_A = tied_module.lora_embedding_A[adapter_name] | ||
| emb_B = tied_module.lora_embedding_B[adapter_name] | ||
|
|
||
| tied_adapters = {"lora_A": emb_B.t(), "lora_B": emb_A.t()} | ||
|
|
||
| kwargs = { | ||
| "r": r, | ||
| "lora_alpha": alpha, | ||
|
|
@@ -204,6 +215,7 @@ def _create_and_replace( | |
| "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), | ||
| "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), | ||
| "parameter_name": parameter_name, | ||
| "tied_adapters": tied_adapters, | ||
| } | ||
|
|
||
| # for torchao merging, we need the get_apply_tensor_subclass from the quantization config | ||
|
|
@@ -249,6 +261,7 @@ def _create_and_replace( | |
| if adapter_name not in self.active_adapters: | ||
| # adding an additional adapter: it is not automatically trainable | ||
| new_module.requires_grad_(False) | ||
|
|
||
| self._replace_module(parent, target_name, new_module, target) | ||
|
|
||
| def _replace_module(self, parent, child_name, new_module, child): | ||
|
|
@@ -806,8 +819,56 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap | |
|
|
||
| return tensors_lora | ||
|
|
||
| def _add_modules_to_tie(self, peft_config, tied_weight_keys): | ||
| modules_to_save = set(getattr(peft_config, "modules_to_save", []) or []) | ||
| missing_keys = set(tied_weight_keys) - modules_to_save | ||
| def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]): | ||
romitjain marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """ | ||
| Tied weight keys contains the layers tied to the embedding layer. Add embedding layer and remove rest of the | ||
| tied layers from `modules_to_save`. Maintain a separate set for layers to be tied | ||
|
|
||
| Args: | ||
| peft_config (LoraConfig) | ||
| tied_weight_keys (list[str]) | ||
| """ | ||
romitjain marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| tied_weight_keys = set(tied_weight_keys) | ||
| peft_config.modules_to_tie = tied_weight_keys | ||
|
|
||
| modules_to_save = getattr(peft_config, "modules_to_save", []) or [] | ||
| if "embed_tokens" not in modules_to_save: | ||
|
||
| modules_to_save.append("embed_tokens") | ||
|
|
||
| for m in tied_weight_keys: | ||
| if m in modules_to_save: | ||
| modules_to_save.remove(m) | ||
|
||
|
|
||
| peft_config.modules_to_save = modules_to_save | ||
|
|
||
| def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]): | ||
| """ | ||
| Tied weight keys contains the layers tied to the embedding layer. Add embedding layer and remove rest of the | ||
| tied layers from `target_modules`. Maintain a separate set for layers to be tied | ||
|
|
||
| Args: | ||
| peft_config (LoraConfig) | ||
| tied_weight_keys (list[str]) | ||
| """ | ||
| tied_weight_keys = set(tied_weight_keys) | ||
| peft_config.target_modules_to_tie = tied_weight_keys | ||
|
|
||
| raw_target_modules = getattr(peft_config, "target_modules", None) | ||
githubnemo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if isinstance(raw_target_modules, str): | ||
githubnemo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # The way weight tying is handled for adapters, we always want to add | ||
| # lora adapters to the input embedding layer (embed_tokens) | ||
| # instead of output embedding layer. | ||
| if "lm_head" in raw_target_modules: | ||
| raw_target_modules = raw_target_modules.replace("lm_head", "embed_tokens") | ||
| peft_config.target_modules = raw_target_modules | ||
| return | ||
|
|
||
| target_modules = set(raw_target_modules or []) | ||
| target_modules.add("embed_tokens") | ||
|
|
||
| for m in tied_weight_keys: | ||
| if m in target_modules: | ||
| target_modules.remove(m) | ||
|
||
|
|
||
| peft_config.modules_to_tie = missing_keys | ||
| peft_config.target_modules = target_modules | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -363,6 +363,30 @@ def _prepare_model(self, peft_config: PeftConfig, model: nn.Module): | |
| """ | ||
| pass | ||
|
|
||
| @staticmethod | ||
| def _check_tied_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Match[str] | None: | ||
| """ | ||
| A helper method to check if the passed module's key name matches any of the tied modules | ||
|
|
||
| Args: | ||
| config (`PeftConfig`): | ||
| A config to match target modules from. | ||
| key (`str`): | ||
| A key to search any matches in config. | ||
|
|
||
| Returns: | ||
| `bool` | ||
| True if key matches any tied modules from config, False if no match found. | ||
| """ | ||
| _target_modules_to_tie = getattr(peft_config, "target_modules_to_tie", {}) or {} | ||
|
|
||
| if key in _target_modules_to_tie or any( | ||
| key.endswith(f".{target_key}") for target_key in _target_modules_to_tie | ||
| ): | ||
| return True | ||
|
|
||
| return False | ||
romitjain marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @staticmethod | ||
| def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Match[str] | None: | ||
| """ | ||
|
|
@@ -699,6 +723,7 @@ def inject_adapter( | |
| excluded_modules = [] | ||
| unmatched_modules = [] | ||
| targeted_modules_from_peft_config: list[str] = [] # only relevant if state_dict is passed | ||
| targets_to_tie: list[str] = [] | ||
| # Note: If possible, all checks should be performed *at the start of this method*. | ||
| # This way, we can raise early if something goes wrong, without leaving the model | ||
| # in a bad (half-initialized) state. | ||
|
|
@@ -787,6 +812,13 @@ def inject_adapter( | |
| if state_dict is None: | ||
| # normal mechanism: match the modules using the peft_config | ||
| result = self._check_target_module_exists(peft_config, key) | ||
| # If the module is a tied layer, then we skip injecting | ||
| # any adapter here and tie it later to the adapter of the source layer. | ||
| # In this loop we only add adapters to the source layer (eg: embed_tokens) | ||
| # Only applicable if `ensure_weight_tying = True` for LoraConfig | ||
| if self._check_tied_module_exists(peft_config, key): | ||
| targets_to_tie.append(key) | ||
| continue | ||
| if isinstance(result, _ExcludedModule): | ||
| excluded_modules.append(key) | ||
| elif not result: | ||
|
|
@@ -805,6 +837,13 @@ def inject_adapter( | |
| if key not in module_names: | ||
| unmatched_modules.append(key) | ||
| else: | ||
| # If the module is a tied layer, then we skip injecting | ||
| # any adapter here and tie it later to the adapter of the source layer. | ||
| # In this loop we only add adapters to the source layer (eg: embed_tokens) | ||
| # Only applicable if `ensure_weight_tying = True` for LoraConfig | ||
| if self._check_tied_module_exists(peft_config, key): | ||
BenjaminBossan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| targets_to_tie.append(key) | ||
| continue | ||
| self.targeted_module_names.append(key) | ||
| parent, target, target_name = _get_submodules(model, key) | ||
| self._check_target_module_compatiblity(peft_config, model, target_name) | ||
|
|
@@ -824,6 +863,16 @@ def inject_adapter( | |
| peft_config=peft_config, model=model, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage | ||
| ) | ||
|
|
||
| # Here we inject tied adapters for all the layers which were tied | ||
| # Only applicable if `ensure_weight_tying = True` for LoraConfig | ||
| for key in targets_to_tie: | ||
| self.targeted_module_names.append(key) | ||
| parent, target, target_name = _get_submodules(model, key) | ||
| self._check_target_module_compatiblity(peft_config, model, target_name) | ||
| ctx = init_empty_weights if low_cpu_mem_usage else nullcontext | ||
| with ctx(): | ||
| self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key) | ||
|
|
||
| #################### | ||
| # CHECK FOR ERRORS # | ||
| #################### | ||
|
|
@@ -910,15 +959,6 @@ def inject_adapter( | |
| RuntimeWarning, | ||
| ) | ||
|
|
||
| tied_target_modules = self._get_tied_target_modules(model=model) | ||
romitjain marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if tied_target_modules: | ||
| warnings.warn( | ||
| f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. " | ||
| "This can lead to complications, for example when merging the adapter " | ||
| "or converting your model to formats other than safetensors. " | ||
| "See for example https://github.com/huggingface/peft/issues/2018." | ||
| ) | ||
|
|
||
| ################ | ||
| # HOUSEKEEPING # | ||
| ################ | ||
|
|
@@ -1198,6 +1238,24 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys): | |
| """ | ||
| This method adds modules to tie to `peft_config` so that those modules can be tied downstream. By default this | ||
| method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this. | ||
|
|
||
| Check `peft.tuners.lora.LoraModel._add_modules_to_tie` for an example. | ||
| """ | ||
| msg = ( | ||
| "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, " | ||
romitjain marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| "but no implementation exists to tie the adapters. " | ||
| "This can lead to complications, for example when merging the adapter " | ||
| "or converting your model to formats other than safetensors. " | ||
| "Check the discussion here: https://github.com/huggingface/peft/issues/2777" | ||
| ) | ||
| warnings.warn(msg) | ||
|
|
||
| def _add_targets_to_tie(self, peft_config, tied_weight_keys): | ||
| """ | ||
| This method adds targets to tie to `peft_config` so that those modules can be tied downstream. By default this | ||
| method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this. | ||
|
|
||
| Check `peft.tuners.lora.LoraModel._add_targets_to_tie` for an example. | ||
| """ | ||
| msg = ( | ||
| "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, " | ||
|
|
@@ -1210,27 +1268,39 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys): | |
|
|
||
| def _check_tied_modules(self, model: nn.Module, peft_config): | ||
| """ | ||
| Checks if any of the tied layers are targeted via `modules_to_save`. Updates the `peft_config.modules_to_tie` | ||
| with any layers that need to be tied | ||
| Checks if any of the tied layers are targeted via `modules_to_save` or `target_modules`. Updates the | ||
| `peft_config` in place with any layers/adapters that need to be tied | ||
| """ | ||
| modules_to_save = set(getattr(peft_config, "modules_to_save", []) or []) | ||
| is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save) | ||
|
|
||
| raw_target_modules = getattr(peft_config, "target_modules", None) | ||
| if isinstance(raw_target_modules, str): | ||
| is_embedding_in_target = any( | ||
| match_target_against_key(raw_target_modules, m) for m in EMBEDDING_LAYER_NAMES | ||
| ) | ||
| else: | ||
| target_modules = set(raw_target_modules or []) | ||
| is_embedding_in_target = any(m in EMBEDDING_LAYER_NAMES for m in target_modules) | ||
|
|
||
| tied_weight_keys = self._get_tied_weight_keys(model) | ||
|
|
||
| if getattr(peft_config, "ensure_weight_tying", False): | ||
| if is_embedding_to_save and tied_weight_keys: | ||
| self._add_modules_to_tie(peft_config, tied_weight_keys) | ||
| if (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys: | ||
|
||
| if is_embedding_to_save: | ||
| self._add_modules_to_tie(peft_config, tied_weight_keys) | ||
| elif is_embedding_in_target: | ||
| self._add_targets_to_tie(peft_config, tied_weight_keys) | ||
|
|
||
| elif not is_embedding_to_save and tied_weight_keys: | ||
| elif not (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys: | ||
| warnings.warn( | ||
| "You have requested `ensure_weight_tying`, but no tied modules are added in `modules_to_save`" | ||
| "You have requested `ensure_weight_tying`, but no tied modules are added in either `modules_to_save` or `target_modules`" | ||
| ) | ||
|
|
||
| elif not tied_weight_keys: | ||
| warnings.warn("You have requested `ensure_weight_tying`, but no tied modules were found in the model") | ||
|
|
||
| elif is_embedding_to_save and tied_weight_keys: | ||
| elif (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys: | ||
| if hasattr(peft_config, "ensure_weight_tying"): | ||
| msg = ( | ||
| "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, " | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.