ENH: Tie weights for target_modules in Lora (#2864) #2879
Changes to the LoRA tuner (`LoraModel`):
```diff
@@ -187,6 +187,15 @@ def _create_and_replace(
         r = lora_config.rank_pattern.get(r_key, lora_config.r)
         alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)

+        is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or [])
+        tied_adapters = {}
+        if is_tied:
+            tied_module = self.model.get_input_embeddings()
+            emb_A = tied_module.lora_embedding_A[adapter_name]
+            emb_B = tied_module.lora_embedding_B[adapter_name]
+
+            tied_adapters = {"lora_A": emb_B.t(), "lora_B": emb_A.t()}
+
         kwargs = {
             "r": r,
             "lora_alpha": alpha,
```
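For context on why the transposes above tie the two adapters: with PEFT's usual shape conventions (assumed here), an embedding LoRA stores `lora_embedding_A` of shape `(r, vocab)` and `lora_embedding_B` of shape `(dim, r)`, while a linear LoRA for `lm_head` stores `lora_A` of shape `(r, in_features)` and `lora_B` of shape `(out_features, r)`. The minimal sketch below (not part of the PR, toy sizes) checks that seeding the `lm_head` adapter with the transposed embedding adapters produces the same weight delta on the shared weight.

```python
import torch

# Assumed toy sizes: vocabulary, hidden dim, LoRA rank.
vocab, dim, r = 10, 8, 4
emb_A = torch.randn(r, vocab)   # lora_embedding_A
emb_B = torch.randn(dim, r)     # lora_embedding_B

# Embedding delta weight (ignoring scaling): (B @ A).T, shape (vocab, dim),
# i.e. the same shape as the tied embed_tokens / lm_head weight.
delta_emb = (emb_B @ emb_A).T

# Tied lm_head adapter as built in the diff: lora_A = emb_B.t(), lora_B = emb_A.t()
lora_A = emb_B.t()              # (r, dim)   -> matches Linear lora_A (r, in_features)
lora_B = emb_A.t()              # (vocab, r) -> matches Linear lora_B (out_features, r)

# Linear delta weight (ignoring scaling): B @ A, shape (vocab, dim).
delta_head = lora_B @ lora_A

# Both deltas modify the shared weight identically.
assert torch.allclose(delta_emb, delta_head, atol=1e-6)
```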
```diff
@@ -204,6 +213,8 @@ def _create_and_replace(
             "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
             "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
             "parameter_name": parameter_name,
+            "is_tied": is_tied,
+            "tied_adapters": tied_adapters,
         }

         # for torchao merging, we need the get_apply_tensor_subclass from the quantization config
```
```diff
@@ -249,9 +260,17 @@ def _create_and_replace(
             if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)
-            self._replace_module(parent, target_name, new_module, target)
+            self._replace_module(
+                parent=parent,
+                child_name=target_name,
+                new_module=new_module,
+                child=target,
+                is_tied=is_tied,
+                adapter_name=adapter_name,
+            )

-    def _replace_module(self, parent, child_name, new_module, child):
+    def _replace_module(self, parent, child_name, new_module, child, is_tied, adapter_name):
         # override in LoraModel to handle quantized weights properly
         setattr(parent, child_name, new_module)
```
```diff
@@ -806,8 +825,36 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap
         return tensors_lora

-    def _add_modules_to_tie(self, peft_config, tied_weight_keys):
-        modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
-        missing_keys = set(tied_weight_keys) - modules_to_save
-        peft_config.modules_to_tie = missing_keys
+    def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
+        """
+        Tied weight keys contain the layers tied to the embedding layer. Add the embedding layer to
+        `modules_to_save`, remove the rest of the tied layers from it, and keep a separate set of layers to tie.
+
+        Args:
+            peft_config (LoraConfig): The adapter config, updated in place.
+            tied_weight_keys (list[str]): Names of the modules whose weights are tied to the embedding.
+        """
+        tied_weight_keys = set(tied_weight_keys)
+        setattr(peft_config, "modules_to_tie", tied_weight_keys)
+
+        modules_to_save = getattr(peft_config, "modules_to_save", []) or []
+        if "embed_tokens" not in modules_to_save:
+            modules_to_save.append("embed_tokens")
+
+        for m in tied_weight_keys:
+            if m in modules_to_save:
+                modules_to_save.remove(m)
+
+        setattr(peft_config, "modules_to_save", modules_to_save)
+
+    def _add_targets_to_tie(self, peft_config, tied_weight_keys):
+        tied_weight_keys = set(tied_weight_keys)
+        setattr(peft_config, "target_modules_to_tie", tied_weight_keys)
+
+        target_modules = set(getattr(peft_config, "target_modules", []) or [])
+        target_modules.add("embed_tokens")
+
+        for m in tied_weight_keys:
+            target_modules.add(m)
+
+        setattr(peft_config, "target_modules", target_modules)
```
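To make the effect of `_add_targets_to_tie` concrete, here is a small standalone sketch (a hypothetical `MiniConfig` stand-in, not `LoraConfig`) mirroring the same logic: the tied names are recorded in `target_modules_to_tie`, and both `embed_tokens` and the tied names end up in `target_modules`.

```python
from dataclasses import dataclass, field

@dataclass
class MiniConfig:  # hypothetical stand-in for LoraConfig
    target_modules: set = field(default_factory=set)
    target_modules_to_tie: set = field(default_factory=set)

def add_targets_to_tie(config: MiniConfig, tied_weight_keys: list) -> None:
    tied = set(tied_weight_keys)
    # Remember which targets must be tied so _create_and_replace can reuse the embedding adapter.
    config.target_modules_to_tie = tied
    # Make sure the embedding and all tied layers are targeted.
    config.target_modules = set(config.target_modules) | {"embed_tokens"} | tied

cfg = MiniConfig(target_modules={"q_proj", "lm_head"})
add_targets_to_tie(cfg, ["lm_head"])
print(sorted(cfg.target_modules))   # ['embed_tokens', 'lm_head', 'q_proj']
print(cfg.target_modules_to_tie)    # {'lm_head'}
```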
Changes to the base tuner (`BaseTuner`):
```diff
@@ -363,6 +363,30 @@ def _prepare_model(self, peft_config: PeftConfig, model: nn.Module):
         """
         pass

+    @staticmethod
+    def _check_tied_module_exists(peft_config: PeftConfig, key: str) -> bool:
+        """
+        A helper method to check if the passed module's key name matches any of the tied modules.
+
+        Args:
+            peft_config (`PeftConfig`):
+                A config to match tied target modules from.
+            key (`str`):
+                A key to search any matches in config.
+
+        Returns:
+            `bool`
+                True if key matches any tied modules from config, False if no match found.
+        """
+        _target_modules_to_tie = getattr(peft_config, "target_modules_to_tie", {}) or {}
+
+        if key in _target_modules_to_tie or any(
+            key.endswith(f".{target_key}") for target_key in _target_modules_to_tie
+        ):
+            return True
+
+        return False
+
     @staticmethod
     def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool | re.Match[str] | None:
         """
```
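The matching rule in `_check_tied_module_exists` is an exact-name or dot-delimited suffix match against `target_modules_to_tie`. A minimal standalone sketch of the same rule (hypothetical helper, not the PEFT method):

```python
def check_tied_module_exists(target_modules_to_tie: set[str], key: str) -> bool:
    # Match either the bare module name or any key that ends with ".<name>".
    return key in target_modules_to_tie or any(
        key.endswith(f".{target_key}") for target_key in target_modules_to_tie
    )

tied = {"lm_head"}
print(check_tied_module_exists(tied, "lm_head"))           # True: exact match
print(check_tied_module_exists(tied, "model.lm_head"))     # True: suffix ".lm_head"
print(check_tied_module_exists(tied, "model.my_lm_head"))  # False: not a dot-delimited suffix
```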
```diff
@@ -699,6 +723,7 @@ def inject_adapter(
         excluded_modules = []
         unmatched_modules = []
         targeted_modules_from_peft_config: list[str] = []  # only relevant if state_dict is passed
+        targets_to_tie: list[str] = []
         # Note: If possible, all checks should be performed *at the start of this method*.
         # This way, we can raise early if something goes wrong, without leaving the model
         # in a bad (half-initialized) state.
```
```diff
@@ -792,6 +817,9 @@ def inject_adapter(
             elif not result:
                 unmatched_modules.append(key)
             else:
+                if self._check_tied_module_exists(peft_config, key):
+                    targets_to_tie.append(key)
+                    continue
                 self.targeted_module_names.append(key)
                 parent, target, target_name = _get_submodules(model, key)
                 self._check_target_module_compatiblity(peft_config, model, target_name)
```
```diff
@@ -805,6 +833,9 @@ def inject_adapter(
                 if key not in module_names:
                     unmatched_modules.append(key)
                 else:
+                    if self._check_tied_module_exists(peft_config, key):
+                        targets_to_tie.append(key)
+                        continue
                     self.targeted_module_names.append(key)
                     parent, target, target_name = _get_submodules(model, key)
                     self._check_target_module_compatiblity(peft_config, model, target_name)
```
```diff
@@ -824,6 +855,15 @@ def inject_adapter(
                 peft_config=peft_config, model=model, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage
             )

+        # Another loop for tying target modules
+        for key in targets_to_tie:
+            self.targeted_module_names.append(key)
+            parent, target, target_name = _get_submodules(model, key)
+            self._check_target_module_compatiblity(peft_config, model, target_name)
+            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+            with ctx():
+                self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
+
         ####################
         # CHECK FOR ERRORS #
         ####################
```
```diff
@@ -910,15 +950,6 @@ def inject_adapter(
                     RuntimeWarning,
                 )

-        tied_target_modules = self._get_tied_target_modules(model=model)
-        if tied_target_modules:
-            warnings.warn(
-                f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
-                "This can lead to complications, for example when merging the adapter "
-                "or converting your model to formats other than safetensors. "
-                "See for example https://github.com/huggingface/peft/issues/2018."
-            )
-
         ################
         # HOUSEKEEPING #
         ################
```
```diff
@@ -1198,6 +1229,24 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys):
         """
         This method adds modules to tie to `peft_config` so that those modules can be tied downstream. By default this
         method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this.

         Check `peft.tuners.lora.LoraModel._add_modules_to_tie` for an example.
         """
         msg = (
             "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
             "but no implementation exists to tie the adapters. "
             "This can lead to complications, for example when merging the adapter "
             "or converting your model to formats other than safetensors. "
             "Check the discussion here: https://github.com/huggingface/peft/issues/2777"
         )
         warnings.warn(msg)

+    def _add_targets_to_tie(self, peft_config, tied_weight_keys):
+        """
+        This method adds targets to tie to `peft_config` so that those modules can be tied downstream. By default this
+        method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this.
+
+        Check `peft.tuners.lora.LoraModel._add_targets_to_tie` for an example.
+        """
+        msg = (
+            "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
```
```diff
@@ -1210,27 +1259,33 @@ def _add_modules_to_tie(self, peft_config, tied_weight_keys):
     def _check_tied_modules(self, model: nn.Module, peft_config):
         """
-        Checks if any of the tied layers are targetted via `modules_to_save`. Updates the `peft_config.modules_to_tie`
-        with any layers that needs to be tied
+        Checks if any of the tied layers are targeted via `modules_to_save` or `target_modules`. Updates the
+        `peft_config` in place with any layers/adapters that need to be tied.
         """
         modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
         is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save)

+        target_modules = set(getattr(peft_config, "target_modules", []) or [])
+        is_embedding_in_target = any(m in EMBEDDING_LAYER_NAMES for m in target_modules)
+
         tied_weight_keys = self._get_tied_weight_keys(model)

         if getattr(peft_config, "ensure_weight_tying", False):
-            if is_embedding_to_save and tied_weight_keys:
-                self._add_modules_to_tie(peft_config, tied_weight_keys)
+            if (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys:
+                if is_embedding_to_save:
+                    self._add_modules_to_tie(peft_config, tied_weight_keys)
+                elif is_embedding_in_target:
+                    self._add_targets_to_tie(peft_config, tied_weight_keys)

-            elif not is_embedding_to_save and tied_weight_keys:
+            elif not (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys:
                 warnings.warn(
-                    "You have requested `ensure_weight_tying`, but no tied modules are added in `modules_to_save`"
+                    "You have requested `ensure_weight_tying`, but no tied modules are added in either `modules_to_save` or `target_modules`"
                 )

             elif not tied_weight_keys:
                 warnings.warn("You have requested `ensure_weight_tying`, but no tied modules were found in the model")

-        elif is_embedding_to_save and tied_weight_keys:
+        elif (is_embedding_to_save or is_embedding_in_target) and tied_weight_keys:
             if hasattr(peft_config, "ensure_weight_tying"):
                 msg = (
                     "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
```
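Putting it together, usage would look roughly like the sketch below. Assumptions: the base model has `tie_word_embeddings=True` (so `lm_head` shares its weight with `embed_tokens`), and `LoraConfig` exposes the `ensure_weight_tying` flag referenced by `_check_tied_modules`, which is introduced alongside this line of work and may not exist in a released PEFT version. With both layers targeted, the `lm_head` adapter is tied to the embedding adapter rather than trained as an independent copy.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Example checkpoint with tie_word_embeddings=True (assumption: any tied-embedding model works).
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj", "embed_tokens", "lm_head"],
    ensure_weight_tying=True,  # assumption: flag added by this feature branch
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```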