[WIP] ENH: Adapter injection based on state_dict #2637
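In short, this draft threads an optional `state_dict` argument from `BaseTuner.__init__` through to `inject_adapter`. When it is given, the modules to wrap are matched against the module names derived from the state_dict keys instead of against `peft_config.target_modules`, and a warning is emitted if the two sources disagree. The snippet below is a small, self-contained illustration of that matching rule, not PEFT code; the module names, keys, and the simplified pattern matching are made up for the example.

```python
import re
import warnings

# Hypothetical module names of a toy model and a made-up adapter state_dict.
module_names = ["encoder.fc1", "encoder.fc2", "head"]
state_dict = {
    "encoder.fc1.lora_A.weight": ...,
    "encoder.fc1.lora_B.weight": ...,
    "head.lora_A.weight": ...,
    "head.lora_B.weight": ...,
}
target_modules = ["fc1", "fc2"]  # what a PEFT config might target

# Same derivation as in the diff: strip the last two key components to get module names.
state_dict_keys = {k.rsplit(".", 2)[0] for k in state_dict}

matched_by_state_dict = [name for name in module_names if name in state_dict_keys]
# Crude stand-in for PEFT's target_modules suffix matching.
matched_by_config = [
    name for name in module_names
    if any(re.search(rf"(^|\.){re.escape(t)}$", name) for t in target_modules)
]

# Mirrors the new "CHECK FOR ERRORS" block: warn when the two sources disagree.
only_config = sorted(set(matched_by_config) - set(matched_by_state_dict))
only_state_dict = sorted(set(matched_by_state_dict) - set(matched_by_config))
if only_config or only_state_dict:
    warnings.warn(
        f"The config would target {only_config}, which the state_dict does not, "
        f"and the state_dict targets {only_state_dict}, which the config does not."
    )

print(matched_by_state_dict)  # ['encoder.fc1', 'head']
print(matched_by_config)      # ['encoder.fc1', 'encoder.fc2']
```

In the diff itself, the per-module work (creating and replacing layers via `_create_and_replace`) happens at the point where a name is matched; the toy above only reproduces the bookkeeping.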
@@ -154,6 +154,8 @@ class BaseTuner(nn.Module, ABC):
 
     The easiest is to check what is done in the `peft.tuners.lora.LoraModel` class.
 
+    FIXME
+
     Attributes:
         model (`torch.nn.Module`):
             The model to which the adapter tuner layers will be attached.
@@ -176,6 +178,7 @@ def __init__(
         peft_config: Union[PeftConfig, dict[str, PeftConfig]],
         adapter_name: str,
         low_cpu_mem_usage: bool = False,
+        state_dict: Optional[dict[str, torch.Tensor]] = None,
     ) -> None:
         super().__init__()
 
@@ -200,7 +203,7 @@ def __init__(
         self.active_adapter: str | list[str] = adapter_name
         self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name)
         if peft_config != PeftType.XLORA or peft_config[adapter_name] != PeftType.XLORA:
-            self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
+            self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage, state_dict=state_dict)
 
         # Copy the peft_config in the injected model.
         self.model.peft_config = self.peft_config
@@ -427,14 +430,21 @@ def _check_target_module_compatiblity(self, peft_config: PeftConfig, model: nn.M
         _check_lora_target_modules_mamba(peft_config, model, target_name)
 
     def inject_adapter(
-        self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False
+        self,
+        model: nn.Module,
+        adapter_name: str,
+        autocast_adapter_dtype: bool = True,
+        low_cpu_mem_usage: bool = False,
+        state_dict: Optional[dict[str, torch.Tensor]] = None,
     ) -> None:
         r"""
         Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the
         hood by `peft.mapping.get_peft_model` if a non-prompt tuning adapter class is passed.
 
         The corresponding PEFT config is directly retrieved from the `peft_config` attribute of the BaseTuner class.
 
+        FIXME
+
         Args:
             model (`nn.Module`):
                 The model to be tuned.
@@ -444,11 +454,17 @@ def inject_adapter(
                 Whether to autocast the adapter dtype. Defaults to `True`.
             low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
                 Create empty adapter weights on meta device. Useful to speed up the loading process.
+            FIXME
 
         """
+        ###################################
+        # PREPARATION OF MODEL AND CONFIG #
+        ###################################
+
         peft_config = self.peft_config[adapter_name]
         excluded_modules = []
         unmatched_modules = []
+        targeted_modules_from_peft_config: list[str] = []  # only relevant if state_dict is passed
         # Note: If possible, all checks should be performed *at the start of this method*.
         # This way, we can raise early if something goes wrong, without leaving the model
         # in a bad (half-initialized) state.
@@ -498,11 +514,18 @@
         if len(new_target_modules) < len(peft_config.target_modules):
             peft_config.target_modules = new_target_modules
 
+        ###############################
+        # MATCHING & CREATING MODULES #
+        ###############################
+
         existing_adapter_map = {}
         for key, module in named_modules:
             if isinstance(module, BaseTunerLayer):
                 existing_adapter_map[key] = module
 
+        # TODO: check if this the most robust way
+        state_dict_keys = {k.rsplit(".", 2)[0] for k in state_dict} if state_dict is not None else set()
+
         for key, module in named_modules:
             if not key:
                 continue
@@ -517,18 +540,67 @@
             if excluded_modules and excluded_modules[-1] == key:
                 continue
 
-            result = self._check_target_module_exists(peft_config, key)
-            if isinstance(result, _ExcludedModule):
-                excluded_modules.append(key)
-            elif not result:
-                unmatched_modules.append(key)
+            if state_dict is None:
+                # normal mechanism: match the modules using the peft_config
+                result = self._check_target_module_exists(peft_config, key)
+                if isinstance(result, _ExcludedModule):
+                    excluded_modules.append(key)
+                elif not result:
+                    unmatched_modules.append(key)
+                else:
+                    self.targeted_module_names.append(key)
+                    parent, target, target_name = _get_submodules(model, key)
+                    self._check_target_module_compatiblity(peft_config, model, target_name)
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        self._create_and_replace(
+                            peft_config, adapter_name, target, target_name, parent, current_key=key
+                        )
[Review comment on lines +567 to +580] Note to reviewers: This is the exact same code as before, just indented by one level.
             else:
-                self.targeted_module_names.append(key)
-                parent, target, target_name = _get_submodules(model, key)
-                self._check_target_module_compatiblity(peft_config, model, target_name)
-                ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-                with ctx():
-                    self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
+                # use the state_dict to match modules instead
+                if key not in state_dict_keys:
+                    unmatched_modules.append(key)
+                else:
+                    self.targeted_module_names.append(key)
+                    parent, target, target_name = _get_submodules(model, key)
+                    self._check_target_module_compatiblity(peft_config, model, target_name)
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        self._create_and_replace(
+                            peft_config, adapter_name, target, target_name, parent, current_key=key
+                        )
+
+                # still record what would have been matched via the config so that the two results can be compared
+                if self._check_target_module_exists(peft_config, key):
+                    targeted_modules_from_peft_config.append(key)
+                    self.targeted_module_names.append(key)
 
+        ####################
+        # CHECK FOR ERRORS #
+        ####################
+
+        if state_dict is not None:
+            # in case that the state_dict was used as source of truth and it resulted in different outcomes than what
+            # would have been matched with the PEFT config, warn the user about that.
+            targeted_set_from_peft_config = set(targeted_modules_from_peft_config)
+            targeted_set_from_state_dict = set(self.targeted_module_names)
+            diff_peft_config = targeted_set_from_peft_config - targeted_set_from_state_dict
+            diff_state_dict = targeted_set_from_state_dict - targeted_set_from_peft_config
+            error_msg = ""
+            if diff_peft_config or diff_state_dict:
+                error_msg = (
+                    "While injecting the PEFT adapters, an inconsistency was discovered between the PEFT config and "
+                    "the provided state_dict. This is not necessarily an issue and can be ignored if this was the "
+                    "intent. "
+                )
+            if diff_peft_config:
+                error_msg += f"The PEFT config contained these additional target modules: {sorted(diff_peft_config)}. "
+            if diff_state_dict:
+                error_msg += f"The state_dict contained these additional target modules: {sorted(diff_state_dict)}. "
+            if error_msg:
+                # FIXME for debugging purposes, raise here
+                1/0
+                warnings.warn(error_msg)
+
         if not self.targeted_module_names and not uses_dummy_target_modules:
             if excluded_modules and not unmatched_modules:
@@ -578,6 +650,10 @@ def inject_adapter(
                 "See for example https://github.com/huggingface/peft/issues/2018."
             )
 
+        ################
+        # HOUSEKEEPING #
+        ################
+
         # It's important to set the adapter here (again), because otherwise it can happen that if a 2nd adapter is
         # added, and it targets different layer(s) than the first adapter (which is active), then those different
         # layers will be activated, which we don't want.
[Review comment] Are these state dict keys or module keys? I think it's the latter, no? For example, `foo.bar.weight`: here the module and submodule are `foo` and `bar`. If we were just obtaining state dict keys, a simple `state_dict.keys()` would have sufficed. So I think we should consider renaming it.
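For concreteness, here is what the `rsplit(".", 2)[0]` expression from the diff produces for a couple of representative (made-up) keys; it yields module names rather than raw state_dict keys, which supports the renaming suggestion above:

```python
keys = [
    "model.layers.0.self_attn.q_proj.lora_A.weight",
    "model.layers.0.self_attn.q_proj.lora_B.weight",
]
print({k.rsplit(".", 2)[0] for k in keys})
# {'model.layers.0.self_attn.q_proj'}  <- the name of the module to be wrapped
```

Note that stripping exactly two components assumes keys of the shape `<module>.lora_A.weight`; a key that also carried an adapter name, such as `<module>.lora_A.default.weight`, would reduce to `<module>.lora_A` instead. That fragility may be what the `# TODO: check if this the most robust way` comment is pointing at.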