@@ -3639,6 +3639,291 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         super().unfuse_lora(components=components, **kwargs)


+class KandinskyLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`Kandinsky5Transformer3DModel`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return the state dict for the LoRA weights and the network alphas.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted on the Hub.
+                    - A path to a *directory* containing the model weights.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository.
+            weight_name (`str`, *optional*, defaults to `None`):
+                Name of the serialized state dict file.
+            use_safetensors (`bool`, *optional*):
+                Whether to use safetensors for loading.
+            return_lora_metadata (`bool`, *optional*, defaults to `False`):
+                When enabled, additionally return the LoRA adapter metadata.
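+
+        Example (a minimal sketch; the repo id and weight file name are illustrative placeholders):
+
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+
+        # The mixin's classmethods are exposed on the pipeline class, so the LoRA
+        # state dict can be inspected without instantiating a pipeline.
+        state_dict = Kandinsky5T2VPipeline.lora_state_dict(
+            "user/kandinsky-lora", weight_name="lora.safetensors"
+        )
+        print(sorted(state_dict.keys())[:5])
+        ```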
+        """
+        # Load the main state dict first, which has the LoRA layers.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+        state_dict, metadata = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. The keys associated with `dora_scale` will be filtered out from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
+
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load the LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            hotswap (`bool`, *optional*):
+                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the
+                random weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
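+
+        Example (a minimal sketch; the LoRA path is a placeholder):
+
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        # `adapter_name` lets several adapters coexist and be toggled later.
+        pipeline.load_lora_weights("path/to/lora.safetensors", adapter_name="my_lora")
+        ```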
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # If a dict is passed, copy it instead of modifying it in place.
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        # Load the LoRA layers into the transformer.
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
+    ):
+        """
+        Load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the LoRA layer parameters.
+            transformer (`Kandinsky5Transformer3DModel`):
+                The transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the
+                random weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`, *optional*):
+                Optional LoRA adapter metadata.
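+
+        Example (a minimal sketch; assumes `pipeline` is a Kandinsky pipeline that uses this mixin, and the
+        LoRA path is a placeholder):
+
+        ```py
+        state_dict = pipeline.lora_state_dict("path/to/lora.safetensors")
+        pipeline.load_lora_into_transformer(
+            state_dict, transformer=pipeline.transformer, adapter_name="my_lora"
+        )
+        ```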
+        """
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to the transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        transformer_lora_adapter_metadata=None,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process.
+            weight_name (`str`, *optional*):
+                Name of the serialized state dict file.
+            save_function (`Callable`):
+                The function to use to save the state dictionary.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way.
+            transformer_lora_adapter_metadata (`dict`, *optional*):
+                LoRA adapter metadata associated with the transformer.
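+
+        Example (a minimal sketch; assumes `transformer` was wrapped and trained with `peft`, and uses
+        `peft`'s `get_peft_model_state_dict` to extract the adapter weights):
+
+        ```py
+        from peft.utils import get_peft_model_state_dict
+
+        from diffusers import Kandinsky5T2VPipeline
+
+        Kandinsky5T2VPipeline.save_lora_weights(
+            save_directory="./my-kandinsky-lora",
+            transformer_lora_layers=get_peft_model_state_dict(transformer),
+        )
+        ```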
+        """
+        lora_layers = {}
+        lora_metadata = {}
+
+        if transformer_lora_layers:
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+        if not lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        cls._save_lora_weights(
+            save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing.
+
+        Example:
+
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        pipeline.load_lora_weights("path/to/lora.safetensors")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of [`pipe.fuse_lora()`].
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
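+
+        Example (continuing the `fuse_lora` sketch above):
+
+        ```py
+        pipeline.fuse_lora(lora_scale=0.7)
+        # ... run inference with the fused weights ...
+        pipeline.unfuse_lora()  # restores the original, unfused weights
+        ```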
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
 class WanLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].