Support for control-lora #10686
```diff
@@ -25,6 +25,7 @@
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
     check_peft_version,
+    convert_control_lora_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
```

```diff
@@ -766,3 +767,184 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             # Pop also the corresponding adapter from the config
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)
+
+
+class ControlLoRAMixin(PeftAdapterMixin):
+    TARGET_MODULES = ["to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2", "proj_in", "proj_out",
+                      "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "linear_1", "linear_2", "time_emb_proj"]
+    SAVE_MODULES = ["controlnet_cond_embedding.conv_in", "controlnet_cond_embedding.blocks.0",
+                    "controlnet_cond_embedding.blocks.1", "controlnet_cond_embedding.blocks.2",
+                    "controlnet_cond_embedding.blocks.3", "controlnet_cond_embedding.blocks.4",
+                    "controlnet_cond_embedding.blocks.5", "controlnet_cond_embedding.conv_out",
+                    "controlnet_down_blocks.0", "controlnet_down_blocks.1", "controlnet_down_blocks.2",
+                    "controlnet_down_blocks.3", "controlnet_down_blocks.4", "controlnet_down_blocks.5",
+                    "controlnet_down_blocks.6", "controlnet_down_blocks.7", "controlnet_down_blocks.8",
+                    "controlnet_mid_block", "norm", "norm1", "norm2", "norm3"]
+
+    def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        adapter_name = kwargs.pop("adapter_name", None)
+        network_alphas = kwargs.pop("network_alphas", None)
+        _pipeline = kwargs.pop("_pipeline", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+        allow_pickle = False
+
+        if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+        if network_alphas is not None and prefix is None:
+            raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
+
+        if prefix is not None:
+            keys = list(state_dict.keys())
+            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
+            if len(model_keys) > 0:
+                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+
+        if len(state_dict) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
+                )
+
+            # check with first key if is not in peft format
+            if "lora_controlnet" in state_dict:
+                del state_dict["lora_controlnet"]
+                state_dict = convert_control_lora_state_dict_to_peft(state_dict)
```
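For orientation, here is a hedged usage sketch of how the new mixin could be exercised, assuming it ends up mixed into `ControlNetModel` (the diff above does not show that wiring); the checkpoint location and file name are placeholders for an original-format Control-LoRA `.safetensors` file:

```python
# Hedged sketch, not the PR's documented API: assumes ControlLoRAMixin is mixed into
# ControlNetModel so that `load_lora_adapter` below resolves to the method added in this PR.
from diffusers import ControlNetModel, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
controlnet = ControlNetModel.from_unet(unet)  # ControlNet initialized from the UNet weights

# Placeholder location and file name for an original-format Control-LoRA checkpoint.
controlnet.load_lora_adapter(
    "path/to/control-lora",
    weight_name="control-lora-canny-rank128.safetensors",
)
```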
> def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):

This method works as expected on the Control LoRA state dict, right?
Another change is to forcibly set the adapter_name to "default".
That would be a breaking change as we support loading multiple adapters. If this is the only change that is required, I think we can simply port it over to load_lora_adapters() of PeftAdapterMixin. WDYT?
I believe this is reasonable. This will be addressed after resolving the other issues.
Why do we need these?
During automated LoRA loading, it is necessary to specify which modules to load, and that is not possible with the original code.
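For concreteness, a minimal sketch of the kind of explicit configuration this is about; the rank/alpha values are placeholders rather than the PR's numbers, and the module lists are trimmed copies of the class-level `TARGET_MODULES` / `SAVE_MODULES` above:

```python
# Sketch only: build a PEFT LoraConfig by naming the modules explicitly instead of
# inferring them from the checkpoint (lists trimmed here for brevity).
from peft import LoraConfig

lora_config = LoraConfig(
    r=128,            # placeholder rank
    lora_alpha=128,   # placeholder alpha
    target_modules=["to_q", "to_k", "to_v", "to_out.0", "conv1", "conv2", "time_emb_proj"],
    modules_to_save=["controlnet_cond_embedding.conv_in", "controlnet_mid_block", "norm1", "norm2"],
)
# The adapter would then be injected via peft.inject_adapter_in_model(lora_config, model, adapter_name).
```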
Can you show me where these are required?
(attachment: weight_name.txt)

This is the original weight file of Control-LoRA. Comparing its format against the Diffusers format shows that we need to use LoRA to fine-tune certain modules while fully training other modules. This implementation can also be found at https://github.com/lavinal712/control-lora-v3/blob/main/train_control_lora_sdxl.py.
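To see that split in the checkpoint itself, a small inspection snippet can help; the file name is a placeholder and the `"lora"` substring test is only an illustrative heuristic, not the converter's exact rule:

```python
# Partition the original Control-LoRA state dict into LoRA-factor tensors and
# fully trained tensors (heuristic, for illustration only).
from safetensors.torch import load_file

state_dict = load_file("control-lora-canny-rank128.safetensors")  # placeholder path
lora_keys = sorted(k for k in state_dict if "lora" in k.lower())
full_keys = sorted(k for k in state_dict if "lora" not in k.lower())

print(f"{len(lora_keys)} LoRA tensors, e.g. {lora_keys[:3]}")
print(f"{len(full_keys)} fully trained tensors, e.g. {full_keys[:3]}")
```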
No, I am asking where these are used in the PR.
In lines 868-873.
Well, we should be able to infer that without having to directly specify it like this. This is what is done for the others: `diffusers/src/diffusers/utils/peft_utils.py`, line 150 at commit 464374f.
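As a rough illustration of that kind of inference (not the referenced helper's exact implementation), the module names and ranks can be read off a PEFT-format state dict instead of being hard-coded:

```python
# Assumes `peft_state_dict` is a converted, PEFT-format dict (e.g. the output of
# convert_control_lora_state_dict_to_peft) with the usual lora_A / lora_B key layout.
def infer_lora_layout(peft_state_dict):
    target_modules, rank_pattern = set(), {}
    for key, tensor in peft_state_dict.items():
        if ".lora_A." in key:
            module_name = key.split(".lora_A.")[0]
            target_modules.add(module_name)
            rank_pattern[module_name] = tensor.shape[0]  # lora_A has shape (rank, in_features)
    return sorted(target_modules), rank_pattern
```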
That function did not achieve the expected effect, so I resorted to specifying the modules directly to meet my purpose.
Well, we need to find a way to tackle this problem. It should not deviate too much from how we go about loading other LoRAs.