@@ -3805,6 +3805,315 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
38053805 super ().unfuse_lora (components = components )
38063806
38073807
3808+ class Lumina2LoraLoaderMixin (LoraBaseMixin ):
3809+ r"""
3810+ Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
3811+ """
3812+
3813+ _lora_loadable_modules = ["transformer" ]
3814+ transformer_name = TRANSFORMER_NAME
3815+
3816+ @classmethod
3817+ @validate_hf_hub_args
3818+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
3819+ def lora_state_dict (
3820+ cls ,
3821+ pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]],
3822+ ** kwargs ,
3823+ ):
3824+ r"""
3825+ Return state dict for lora weights and the network alphas.
3826+
3827+ <Tip warning={true}>
3828+
3829+ We support loading original format HunyuanVideo LoRA checkpoints.
3830+
3831+ This function is experimental and might change in the future.
3832+
3833+ </Tip>
3834+
3835+ Parameters:
3836+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3837+ Can be either:
3838+
3839+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3840+ the Hub.
3841+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3842+ with [`ModelMixin.save_pretrained`].
3843+ - A [torch state
3844+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3845+
3846+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3847+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3848+ is not used.
3849+ force_download (`bool`, *optional*, defaults to `False`):
3850+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3851+ cached versions if they exist.
3852+
3853+ proxies (`Dict[str, str]`, *optional*):
3854+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3855+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3856+ local_files_only (`bool`, *optional*, defaults to `False`):
3857+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3858+ won't be downloaded from the Hub.
3859+ token (`str` or *bool*, *optional*):
3860+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3861+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3862+ revision (`str`, *optional*, defaults to `"main"`):
3863+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3864+ allowed by Git.
3865+ subfolder (`str`, *optional*, defaults to `""`):
3866+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3867+
3868+ """
3869+ # Load the main state dict first which has the LoRA layers for either of
3870+ # transformer and text encoder or both.
3871+ cache_dir = kwargs .pop ("cache_dir" , None )
3872+ force_download = kwargs .pop ("force_download" , False )
3873+ proxies = kwargs .pop ("proxies" , None )
3874+ local_files_only = kwargs .pop ("local_files_only" , None )
3875+ token = kwargs .pop ("token" , None )
3876+ revision = kwargs .pop ("revision" , None )
3877+ subfolder = kwargs .pop ("subfolder" , None )
3878+ weight_name = kwargs .pop ("weight_name" , None )
3879+ use_safetensors = kwargs .pop ("use_safetensors" , None )
3880+
3881+ allow_pickle = False
3882+ if use_safetensors is None :
3883+ use_safetensors = True
3884+ allow_pickle = True
3885+
3886+ user_agent = {
3887+ "file_type" : "attn_procs_weights" ,
3888+ "framework" : "pytorch" ,
3889+ }
3890+
3891+ state_dict = _fetch_state_dict (
3892+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
3893+ weight_name = weight_name ,
3894+ use_safetensors = use_safetensors ,
3895+ local_files_only = local_files_only ,
3896+ cache_dir = cache_dir ,
3897+ force_download = force_download ,
3898+ proxies = proxies ,
3899+ token = token ,
3900+ revision = revision ,
3901+ subfolder = subfolder ,
3902+ user_agent = user_agent ,
3903+ allow_pickle = allow_pickle ,
3904+ )
3905+
3906+ is_dora_scale_present = any ("dora_scale" in k for k in state_dict )
3907+ if is_dora_scale_present :
3908+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
3909+ logger .warning (warn_msg )
3910+ state_dict = {k : v for k , v in state_dict .items () if "dora_scale" not in k }
3911+
3912+ is_original_hunyuan_video = any ("img_attn_qkv" in k for k in state_dict )
3913+ if is_original_hunyuan_video :
3914+ state_dict = _convert_hunyuan_video_lora_to_diffusers (state_dict )
3915+
3916+ return state_dict
3917+
3918+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3919+ def load_lora_weights (
3920+ self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], adapter_name = None , ** kwargs
3921+ ):
3922+ """
3923+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3924+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3925+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3926+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3927+ dict is loaded into `self.transformer`.
3928+
3929+ Parameters:
3930+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3931+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3932+ adapter_name (`str`, *optional*):
3933+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3934+ `default_{i}` where i is the total number of adapters being loaded.
3935+ low_cpu_mem_usage (`bool`, *optional*):
3936+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3937+ weights.
3938+ kwargs (`dict`, *optional*):
3939+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3940+ """
3941+ if not USE_PEFT_BACKEND :
3942+ raise ValueError ("PEFT backend is required for this method." )
3943+
3944+ low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT_LORA )
3945+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
3946+ raise ValueError (
3947+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3948+ )
3949+
3950+ # if a dict is passed, copy it instead of modifying it inplace
3951+ if isinstance (pretrained_model_name_or_path_or_dict , dict ):
3952+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict .copy ()
3953+
3954+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3955+ state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
3956+
3957+ is_correct_format = all ("lora" in key for key in state_dict .keys ())
3958+ if not is_correct_format :
3959+ raise ValueError ("Invalid LoRA checkpoint." )
3960+
3961+ self .load_lora_into_transformer (
3962+ state_dict ,
3963+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
3964+ adapter_name = adapter_name ,
3965+ _pipeline = self ,
3966+ low_cpu_mem_usage = low_cpu_mem_usage ,
3967+ )
3968+
3969+ @classmethod
3970+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
3971+ def load_lora_into_transformer (
3972+ cls , state_dict , transformer , adapter_name = None , _pipeline = None , low_cpu_mem_usage = False
3973+ ):
3974+ """
3975+ This will load the LoRA layers specified in `state_dict` into `transformer`.
3976+
3977+ Parameters:
3978+ state_dict (`dict`):
3979+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3980+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3981+ encoder lora layers.
3982+ transformer (`HunyuanVideoTransformer3DModel`):
3983+ The Transformer model to load the LoRA layers into.
3984+ adapter_name (`str`, *optional*):
3985+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3986+ `default_{i}` where i is the total number of adapters being loaded.
3987+ low_cpu_mem_usage (`bool`, *optional*):
3988+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3989+ weights.
3990+ """
3991+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
3992+ raise ValueError (
3993+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3994+ )
3995+
3996+ # Load the layers corresponding to transformer.
3997+ logger .info (f"Loading { cls .transformer_name } ." )
3998+ transformer .load_lora_adapter (
3999+ state_dict ,
4000+ network_alphas = None ,
4001+ adapter_name = adapter_name ,
4002+ _pipeline = _pipeline ,
4003+ low_cpu_mem_usage = low_cpu_mem_usage ,
4004+ )
4005+
4006+ @classmethod
4007+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4008+ def save_lora_weights (
4009+ cls ,
4010+ save_directory : Union [str , os .PathLike ],
4011+ transformer_lora_layers : Dict [str , Union [torch .nn .Module , torch .Tensor ]] = None ,
4012+ is_main_process : bool = True ,
4013+ weight_name : str = None ,
4014+ save_function : Callable = None ,
4015+ safe_serialization : bool = True ,
4016+ ):
4017+ r"""
4018+ Save the LoRA parameters corresponding to the UNet and text encoder.
4019+
4020+ Arguments:
4021+ save_directory (`str` or `os.PathLike`):
4022+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
4023+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
4024+ State dict of the LoRA layers corresponding to the `transformer`.
4025+ is_main_process (`bool`, *optional*, defaults to `True`):
4026+ Whether the process calling this is the main process or not. Useful during distributed training and you
4027+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
4028+ process to avoid race conditions.
4029+ save_function (`Callable`):
4030+ The function to use to save the state dictionary. Useful during distributed training when you need to
4031+ replace `torch.save` with another method. Can be configured with the environment variable
4032+ `DIFFUSERS_SAVE_MODE`.
4033+ safe_serialization (`bool`, *optional*, defaults to `True`):
4034+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4035+ """
4036+ state_dict = {}
4037+
4038+ if not transformer_lora_layers :
4039+ raise ValueError ("You must pass `transformer_lora_layers`." )
4040+
4041+ if transformer_lora_layers :
4042+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
4043+
4044+ # Save the model
4045+ cls .write_lora_layers (
4046+ state_dict = state_dict ,
4047+ save_directory = save_directory ,
4048+ is_main_process = is_main_process ,
4049+ weight_name = weight_name ,
4050+ save_function = save_function ,
4051+ safe_serialization = safe_serialization ,
4052+ )
4053+
4054+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
4055+ def fuse_lora (
4056+ self ,
4057+ components : List [str ] = ["transformer" ],
4058+ lora_scale : float = 1.0 ,
4059+ safe_fusing : bool = False ,
4060+ adapter_names : Optional [List [str ]] = None ,
4061+ ** kwargs ,
4062+ ):
4063+ r"""
4064+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
4065+
4066+ <Tip warning={true}>
4067+
4068+ This is an experimental API.
4069+
4070+ </Tip>
4071+
4072+ Args:
4073+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
4074+ lora_scale (`float`, defaults to 1.0):
4075+ Controls how much to influence the outputs with the LoRA parameters.
4076+ safe_fusing (`bool`, defaults to `False`):
4077+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
4078+ adapter_names (`List[str]`, *optional*):
4079+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
4080+
4081+ Example:
4082+
4083+ ```py
4084+ from diffusers import DiffusionPipeline
4085+ import torch
4086+
4087+ pipeline = DiffusionPipeline.from_pretrained(
4088+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
4089+ ).to("cuda")
4090+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
4091+ pipeline.fuse_lora(lora_scale=0.7)
4092+ ```
4093+ """
4094+ super ().fuse_lora (
4095+ components = components , lora_scale = lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names
4096+ )
4097+
4098+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
4099+ def unfuse_lora (self , components : List [str ] = ["transformer" ], ** kwargs ):
4100+ r"""
4101+ Reverses the effect of
4102+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4103+
4104+ <Tip warning={true}>
4105+
4106+ This is an experimental API.
4107+
4108+ </Tip>
4109+
4110+ Args:
4111+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4112+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
4113+ """
4114+ super ().unfuse_lora (components = components )
4115+
4116+
38084117class LoraLoaderMixin (StableDiffusionLoraLoaderMixin ):
38094118 def __init__ (self , * args , ** kwargs ):
38104119 deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
0 commit comments