@@ -3639,6 +3639,291 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         super().unfuse_lora(components=components, **kwargs)


+class KandinskyLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`Kandinsky5Transformer3DModel`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return the state dict for the LoRA weights (and, optionally, the LoRA adapter metadata).
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+                    - A string, the *model id* of a pretrained model hosted on the Hub.
+                    - A path to a *directory* containing the model weights.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository.
+            weight_name (`str`, *optional*, defaults to `None`):
+                Name of the serialized state dict file.
+            use_safetensors (`bool`, *optional*):
+                Whether to use safetensors for loading.
+            return_lora_metadata (`bool`, *optional*, defaults to `False`):
+                When enabled, additionally return the LoRA adapter metadata.
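+
+        Example (a minimal sketch, assuming `Kandinsky5T2VPipeline` mixes in this class; the repo id and
+        weight name are illustrative placeholders, not a real checkpoint):
+
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+
+        # `lora_state_dict` is a classmethod, so no pipeline instance is needed.
+        state_dict = Kandinsky5T2VPipeline.lora_state_dict(
+            "user/kandinsky-lora",  # hypothetical Hub repo id
+            weight_name="pytorch_lora_weights.safetensors",
+        )
+        print(len(state_dict))  # number of LoRA tensors in the checkpoint
+        ```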
+        """
+        # Load the main state dict first which has the LoRA layers
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+        state_dict, metadata = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue: https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
+
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            hotswap (`bool`, *optional*):
+                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
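+
+        Example (a minimal sketch; the LoRA path and adapter name are illustrative placeholders):
+
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        # Register the LoRA under an explicit adapter name so it can be fused or swapped later.
+        pipeline.load_lora_weights("path/to/lora.safetensors", adapter_name="my_lora")
+        ```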
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # If a dict is passed, copy it instead of modifying it in place.
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        # Load LoRA into the transformer.
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
+    ):
+        """
+        Load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the LoRA layer parameters.
+            transformer (`Kandinsky5Transformer3DModel`):
+                The transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`, *optional*):
+                Optional LoRA adapter metadata.
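+
+        Example (a minimal sketch; `load_lora_weights` normally calls this for you, and the path below is a
+        placeholder):
+
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        state_dict = pipeline.lora_state_dict("path/to/lora.safetensors")
+        # Load the LoRA layers into the transformer component only.
+        pipeline.load_lora_into_transformer(
+            state_dict, transformer=pipeline.transformer, adapter_name="my_lora"
+        )
+        ```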
+        """
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to the transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        transformer_lora_adapter_metadata=None,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process.
+            weight_name (`str`, *optional*, defaults to `None`):
+                Name of the serialized state dict file.
+            save_function (`Callable`, *optional*):
+                The function to use to save the state dictionary.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way.
+            transformer_lora_adapter_metadata (`dict`, *optional*):
+                LoRA adapter metadata associated with the transformer.
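+
+        Example (a minimal sketch, assuming the transformer already carries a PEFT LoRA adapter; the paths are
+        placeholders):
+
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+        from peft import get_peft_model_state_dict
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        pipeline.load_lora_weights("path/to/lora.safetensors")
+        # Extract the adapter weights from the transformer and serialize them.
+        lora_layers = get_peft_model_state_dict(pipeline.transformer)
+        Kandinsky5T2VPipeline.save_lora_weights(
+            save_directory="path/to/output", transformer_lora_layers=lora_layers
+        )
+        ```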
+        """
+        lora_layers = {}
+        lora_metadata = {}
+
+        if transformer_lora_layers:
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+        if not lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        cls._save_lora_weights(
+            save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing.
+
+        Example:
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        pipeline.load_lora_weights("path/to/lora.safetensors")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of [`pipe.fuse_lora()`].
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
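+
+        Example (continuing the `fuse_lora` sketch above):
+
+        ```py
+        pipeline.fuse_lora(lora_scale=0.7)
+        # ... run inference with the fused weights ...
+        pipeline.unfuse_lora()  # restores the original transformer weights
+        ```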
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
 class WanLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].