[LoRA] Implement hot-swapping of LoRA #9453
```diff
@@ -66,7 +66,7 @@ class UNet2DConditionLoadersMixin:
     unet_name = UNET_NAME

     @validate_hf_hub_args
-    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], hotswap: bool = False, **kwargs):
         r"""
         Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
         defined in
```
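For orientation, a minimal sketch of how the new `hotswap` flag is meant to be used; the repository ids, weight file names, and adapter name below are placeholders, and the exact public entry point may differ from what the PR finally exposes:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# First load registers a LoRA adapter under a name (placeholder repo/file names).
pipe.unet.load_attn_procs(
    "user/lora-style-a", weight_name="pytorch_lora_weights.safetensors", adapter_name="style"
)

# Later, replace the weights of that same adapter in place instead of
# registering a second adapter; the adapter name must already exist.
pipe.unet.load_attn_procs(
    "user/lora-style-b", weight_name="pytorch_lora_weights.safetensors", adapter_name="style", hotswap=True
)
```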
```diff
@@ -115,6 +115,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 `default_{i}` where i is the total number of adapters being loaded.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            hotswap TODO

         Example:
```
```diff
@@ -209,6 +210,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
+                hotswap=hotswap,
             )
         else:
             raise ValueError(
```
```diff
@@ -268,7 +270,7 @@ def _process_custom_diffusion(self, state_dict):

         return attn_processors

-    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, hotswap: bool = False):
         # This method does the following things:
         # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
         #    format. For legacy format no filtering is applied.
```
```diff
@@ -299,10 +301,12 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
         state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

         if len(state_dict_to_be_used) > 0:
-            if adapter_name in getattr(self, "peft_config", {}):
+            if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
                 )
+            elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
+                raise ValueError(f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name.")

             state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
```
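To make the two error paths concrete, a stand-alone sketch of the check above, with a plain dict standing in for `self.peft_config` (the helper name and the values are illustrative):

```python
def check_adapter_name(peft_config: dict, adapter_name: str, hotswap: bool) -> None:
    # Mirrors the validation above: a regular load must use a fresh name,
    # while a hot swap must target a name that is already loaded.
    if adapter_name in peft_config and not hotswap:
        raise ValueError(
            f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
        )
    elif adapter_name not in peft_config and hotswap:
        raise ValueError(
            f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name."
        )


check_adapter_name({"default_0": object()}, "default_0", hotswap=True)   # ok: swap an existing adapter
check_adapter_name({}, "default_0", hotswap=False)                       # ok: fresh load
# check_adapter_name({}, "default_0", hotswap=True)                      # raises ValueError
```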
```diff
@@ -336,8 +340,108 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
             # otherwise loading LoRA weights will lead to an error
             is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

             inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
             incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)

```
```diff
+def _check_hotswap_configs_compatible(config0, config1):
+    # To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be incorrect. E.g. if
+    # they use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with
+    # the weights from the second adapter, which would result in incorrect behavior. There is probably a way to swap
+    # these values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is
+    # compiled.
+
+    # TODO: This is a very rough check at the moment and there are probably better ways than to error out
+    config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
+    config0 = config0.to_dict()
+    config1 = config1.to_dict()
+    for key in config_keys_to_check:
+        val0 = config0[key]
+        val1 = config1[key]
+        if val0 != val1:
+            raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")
```
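For illustration, the same check run on two `peft.LoraConfig` objects that differ only in `lora_alpha` (this assumes a recent `peft` release where all five config keys exist):

```python
from peft import LoraConfig

# Two configs that differ only in lora_alpha; after a naive weight swap the
# old scaling would still be applied to the new weights, so the check refuses.
config0 = LoraConfig(r=8, lora_alpha=8)
config1 = LoraConfig(r=8, lora_alpha=16)

for key in ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]:
    val0, val1 = config0.to_dict()[key], config1.to_dict()[key]
    if val0 != val1:
        raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")
# ValueError: Configs are incompatible: for lora_alpha, 8 != 16
```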
```diff
+
+
+def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
+    """
+    Swap out the LoRA weights in the model with the weights from state_dict.
+
+    It is assumed that the existing adapter and the new adapter are compatible.
+
+    Args:
+        model: nn.Module
+            The model with the loaded adapter.
+        state_dict: dict[str, torch.Tensor]
+            The state dict of the new adapter, which needs to be compatible (targeting the same modules etc.).
+        adapter_name: Optional[str]
+            The name of the adapter that should be hot-swapped.
+
+    Raises:
+        RuntimeError
+            If the old and the new adapter are not compatible, a RuntimeError is raised.
+    """
+    from operator import attrgetter
+
+    #######################
+    # INSERT ADAPTER NAME #
+    #######################
+
+    remapped_state_dict = {}
+    expected_str = adapter_name + "."
+    for key, val in state_dict.items():
+        if expected_str not in key:
+            prefix, _, suffix = key.rpartition(".")
+            key = f"{prefix}.{adapter_name}.{suffix}"
+        remapped_state_dict[key] = val
+    state_dict = remapped_state_dict
```
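The remapping only inserts the adapter name into keys that do not already carry it; a small sketch with made-up keys and string placeholders for the tensors:

```python
adapter_name = "default_0"
state_dict = {
    # a typical serialized LoRA key without an adapter name ...
    "mid_block.attentions.0.proj_in.lora_A.weight": "tensor A",
    # ... and one that already carries the adapter name (left untouched)
    "mid_block.attentions.0.proj_out.lora_B.default_0.weight": "tensor B",
}

remapped_state_dict = {}
expected_str = adapter_name + "."
for key, val in state_dict.items():
    if expected_str not in key:
        prefix, _, suffix = key.rpartition(".")
        key = f"{prefix}.{adapter_name}.{suffix}"
    remapped_state_dict[key] = val

print(list(remapped_state_dict))
# ['mid_block.attentions.0.proj_in.lora_A.default_0.weight',
#  'mid_block.attentions.0.proj_out.lora_B.default_0.weight']
```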
```diff
+    ####################
+    # CHECK STATE_DICT #
+    ####################
+
+    # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise
+    # hot-swapping is not possible
+    parameter_prefix = "lora_"  # hard-coded for now
+    is_compiled = hasattr(model, "_orig_mod")
+    # TODO: there is probably a more precise way to identify the adapter keys
+    missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
+    unexpected_keys = set()
+
+    # first: dry run, not swapping anything
+    for key, new_val in state_dict.items():
+        try:
+            old_val = attrgetter(key)(model)
+        except AttributeError:
+            unexpected_keys.add(key)
+            continue
+
+        if is_compiled:
+            missing_keys.remove("_orig_mod." + key)
+        else:
+            missing_keys.remove(key)
+
+    if missing_keys or unexpected_keys:
+        msg = "Hot swapping the adapter did not succeed."
+        if missing_keys:
+            msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
+        if unexpected_keys:
+            msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
+        raise RuntimeError(msg)
```
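The dry run resolves each dotted state-dict key to the live parameter with `operator.attrgetter`; keys that cannot be resolved become unexpected keys. A minimal sketch with a toy module (module and key names are illustrative):

```python
from operator import attrgetter

import torch.nn as nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(4, 4)


model = Toy()

# Resolves "proj_in.weight" like getattr chained over the dots.
weight = attrgetter("proj_in.weight")(model)
print(weight.shape)  # torch.Size([4, 4])

# A key that does not exist raises AttributeError, which the dry run above
# records as an unexpected key instead of failing mid-swap.
try:
    attrgetter("proj_in.lora_A.default_0.weight")(model)
except AttributeError:
    print("unexpected key")
```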
```diff
+    ###################
+    # ACTUAL SWAPPING #
+    ###################
+
+    for key, new_val in state_dict.items():
+        # no need to account for potential _orig_mod in key here, as torch handles that
+        old_val = attrgetter(key)(model)
+        old_val.data = new_val.data.to(device=old_val.device)
```
Suggested change:

```diff
-        old_val.data = new_val.data.to(device=old_val.device)
+        old_val.data.copy_(new_val.data.to(device=old_val.device))
```
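The suggested in-place `copy_` is commonly preferred when the model may be wrapped by `torch.compile`: it writes into the storage the parameter already owns, so its data pointer stays stable, whereas rebinding `.data` points the parameter at fresh storage. That rationale is an inference here, not stated in the diff; a toy comparison:

```python
import torch

param = torch.nn.Parameter(torch.zeros(4))
new_weights = torch.ones(4)

ptr_before = param.data.data_ptr()

# In-place copy: writes into the existing storage, data pointer unchanged.
param.data.copy_(new_weights.to(device=param.device))
assert param.data.data_ptr() == ptr_before

# Rebinding instead would point .data at new storage:
#   param.data = new_weights.to(device=param.device)
# after which data_ptr() changes, which a compiled graph may not expect.
```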