@@ -3870,6 +3870,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
38703870 super ().unfuse_lora (components = components )
38713871
38723872
3873+ class HunyuanVideoLoraLoaderMixin (LoraBaseMixin ):
3874+ r"""
3875+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
3876+ """
3877+
3878+ _lora_loadable_modules = ["transformer" ]
3879+ transformer_name = TRANSFORMER_NAME
3880+
3881+ @classmethod
3882+ @validate_hf_hub_args
3883+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3884+ def lora_state_dict (
3885+ cls ,
3886+ pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]],
3887+ ** kwargs ,
3888+ ):
3889+ r"""
3890+ Return state dict for lora weights and the network alphas.
3891+
3892+ <Tip warning={true}>
3893+
3894+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
3895+
3896+ This function is experimental and might change in the future.
3897+
3898+ </Tip>
3899+
3900+ Parameters:
3901+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3902+ Can be either:
3903+
3904+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3905+ the Hub.
3906+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3907+ with [`ModelMixin.save_pretrained`].
3908+ - A [torch state
3909+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3910+
3911+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3912+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3913+ is not used.
3914+ force_download (`bool`, *optional*, defaults to `False`):
3915+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3916+ cached versions if they exist.
3917+
3918+ proxies (`Dict[str, str]`, *optional*):
3919+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3920+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3921+ local_files_only (`bool`, *optional*, defaults to `False`):
3922+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3923+ won't be downloaded from the Hub.
3924+ token (`str` or *bool*, *optional*):
3925+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3926+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3927+ revision (`str`, *optional*, defaults to `"main"`):
3928+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3929+ allowed by Git.
3930+ subfolder (`str`, *optional*, defaults to `""`):
3931+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3932+
3933+ """
3934+ # Load the main state dict first which has the LoRA layers for either of
3935+ # transformer and text encoder or both.
3936+ cache_dir = kwargs .pop ("cache_dir" , None )
3937+ force_download = kwargs .pop ("force_download" , False )
3938+ proxies = kwargs .pop ("proxies" , None )
3939+ local_files_only = kwargs .pop ("local_files_only" , None )
3940+ token = kwargs .pop ("token" , None )
3941+ revision = kwargs .pop ("revision" , None )
3942+ subfolder = kwargs .pop ("subfolder" , None )
3943+ weight_name = kwargs .pop ("weight_name" , None )
3944+ use_safetensors = kwargs .pop ("use_safetensors" , None )
3945+
3946+ allow_pickle = False
3947+ if use_safetensors is None :
3948+ use_safetensors = True
3949+ allow_pickle = True
3950+
3951+ user_agent = {
3952+ "file_type" : "attn_procs_weights" ,
3953+ "framework" : "pytorch" ,
3954+ }
3955+
3956+ state_dict = _fetch_state_dict (
3957+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
3958+ weight_name = weight_name ,
3959+ use_safetensors = use_safetensors ,
3960+ local_files_only = local_files_only ,
3961+ cache_dir = cache_dir ,
3962+ force_download = force_download ,
3963+ proxies = proxies ,
3964+ token = token ,
3965+ revision = revision ,
3966+ subfolder = subfolder ,
3967+ user_agent = user_agent ,
3968+ allow_pickle = allow_pickle ,
3969+ )
3970+
3971+ is_dora_scale_present = any ("dora_scale" in k for k in state_dict )
3972+ if is_dora_scale_present :
3973+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
3974+ logger .warning (warn_msg )
3975+ state_dict = {k : v for k , v in state_dict .items () if "dora_scale" not in k }
3976+
3977+ return state_dict
3978+
3979+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3980+ def load_lora_weights (
3981+ self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], adapter_name = None , ** kwargs
3982+ ):
3983+ """
3984+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3985+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3986+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3987+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3988+ dict is loaded into `self.transformer`.
3989+
3990+ Parameters:
3991+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3992+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3993+ adapter_name (`str`, *optional*):
3994+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3995+ `default_{i}` where i is the total number of adapters being loaded.
3996+ low_cpu_mem_usage (`bool`, *optional*):
3997+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3998+ weights.
3999+ kwargs (`dict`, *optional*):
4000+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4001+ """
4002+ if not USE_PEFT_BACKEND :
4003+ raise ValueError ("PEFT backend is required for this method." )
4004+
4005+ low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT_LORA )
4006+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
4007+ raise ValueError (
4008+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4009+ )
4010+
4011+ # if a dict is passed, copy it instead of modifying it inplace
4012+ if isinstance (pretrained_model_name_or_path_or_dict , dict ):
4013+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict .copy ()
4014+
4015+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4016+ state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
4017+
4018+ is_correct_format = all ("lora" in key for key in state_dict .keys ())
4019+ if not is_correct_format :
4020+ raise ValueError ("Invalid LoRA checkpoint." )
4021+
4022+ self .load_lora_into_transformer (
4023+ state_dict ,
4024+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
4025+ adapter_name = adapter_name ,
4026+ _pipeline = self ,
4027+ low_cpu_mem_usage = low_cpu_mem_usage ,
4028+ )
4029+
4030+ @classmethod
4031+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
4032+ def load_lora_into_transformer (
4033+ cls , state_dict , transformer , adapter_name = None , _pipeline = None , low_cpu_mem_usage = False
4034+ ):
4035+ """
4036+ This will load the LoRA layers specified in `state_dict` into `transformer`.
4037+
4038+ Parameters:
4039+ state_dict (`dict`):
4040+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4041+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4042+ encoder lora layers.
4043+ transformer (`HunyuanVideoTransformer3DModel`):
4044+ The Transformer model to load the LoRA layers into.
4045+ adapter_name (`str`, *optional*):
4046+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4047+ `default_{i}` where i is the total number of adapters being loaded.
4048+ low_cpu_mem_usage (`bool`, *optional*):
4049+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4050+ weights.
4051+ """
4052+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
4053+ raise ValueError (
4054+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4055+ )
4056+
4057+ # Load the layers corresponding to transformer.
4058+ logger .info (f"Loading { cls .transformer_name } ." )
4059+ transformer .load_lora_adapter (
4060+ state_dict ,
4061+ network_alphas = None ,
4062+ adapter_name = adapter_name ,
4063+ _pipeline = _pipeline ,
4064+ low_cpu_mem_usage = low_cpu_mem_usage ,
4065+ )
4066+
4067+ @classmethod
4068+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4069+ def save_lora_weights (
4070+ cls ,
4071+ save_directory : Union [str , os .PathLike ],
4072+ transformer_lora_layers : Dict [str , Union [torch .nn .Module , torch .Tensor ]] = None ,
4073+ is_main_process : bool = True ,
4074+ weight_name : str = None ,
4075+ save_function : Callable = None ,
4076+ safe_serialization : bool = True ,
4077+ ):
4078+ r"""
4079+ Save the LoRA parameters corresponding to the UNet and text encoder.
4080+
4081+ Arguments:
4082+ save_directory (`str` or `os.PathLike`):
4083+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
4084+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
4085+ State dict of the LoRA layers corresponding to the `transformer`.
4086+ is_main_process (`bool`, *optional*, defaults to `True`):
4087+ Whether the process calling this is the main process or not. Useful during distributed training and you
4088+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
4089+ process to avoid race conditions.
4090+ save_function (`Callable`):
4091+ The function to use to save the state dictionary. Useful during distributed training when you need to
4092+ replace `torch.save` with another method. Can be configured with the environment variable
4093+ `DIFFUSERS_SAVE_MODE`.
4094+ safe_serialization (`bool`, *optional*, defaults to `True`):
4095+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4096+ """
4097+ state_dict = {}
4098+
4099+ if not transformer_lora_layers :
4100+ raise ValueError ("You must pass `transformer_lora_layers`." )
4101+
4102+ if transformer_lora_layers :
4103+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
4104+
4105+ # Save the model
4106+ cls .write_lora_layers (
4107+ state_dict = state_dict ,
4108+ save_directory = save_directory ,
4109+ is_main_process = is_main_process ,
4110+ weight_name = weight_name ,
4111+ save_function = save_function ,
4112+ safe_serialization = safe_serialization ,
4113+ )
4114+
4115+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
4116+ def fuse_lora (
4117+ self ,
4118+ components : List [str ] = ["transformer" , "text_encoder" ],
4119+ lora_scale : float = 1.0 ,
4120+ safe_fusing : bool = False ,
4121+ adapter_names : Optional [List [str ]] = None ,
4122+ ** kwargs ,
4123+ ):
4124+ r"""
4125+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
4126+
4127+ <Tip warning={true}>
4128+
4129+ This is an experimental API.
4130+
4131+ </Tip>
4132+
4133+ Args:
4134+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
4135+ lora_scale (`float`, defaults to 1.0):
4136+ Controls how much to influence the outputs with the LoRA parameters.
4137+ safe_fusing (`bool`, defaults to `False`):
4138+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
4139+ adapter_names (`List[str]`, *optional*):
4140+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
4141+
4142+ Example:
4143+
4144+ ```py
4145+ from diffusers import DiffusionPipeline
4146+ import torch
4147+
4148+ pipeline = DiffusionPipeline.from_pretrained(
4149+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
4150+ ).to("cuda")
4151+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
4152+ pipeline.fuse_lora(lora_scale=0.7)
4153+ ```
4154+ """
4155+ super ().fuse_lora (
4156+ components = components , lora_scale = lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names
4157+ )
4158+
4159+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
4160+ def unfuse_lora (self , components : List [str ] = ["transformer" , "text_encoder" ], ** kwargs ):
4161+ r"""
4162+ Reverses the effect of
4163+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4164+
4165+ <Tip warning={true}>
4166+
4167+ This is an experimental API.
4168+
4169+ </Tip>
4170+
4171+ Args:
4172+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4173+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
4174+ unfuse_text_encoder (`bool`, defaults to `True`):
4175+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
4176+ LoRA parameters then it won't have any effect.
4177+ """
4178+ super ().unfuse_lora (components = components )
4179+
4180+
38734181class LoraLoaderMixin (StableDiffusionLoraLoaderMixin ):
38744182 def __init__ (self , * args , ** kwargs ):
38754183 deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
0 commit comments