@@ -2364,7 +2364,7 @@ def save_lora_weights(
 
 class CogVideoXLoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`].
+    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
     """
 
     _lora_loadable_modules = ["transformer"]
@@ -2669,6 +2669,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
         super().unfuse_lora(components=components)
 
 
+class Mochi1LoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
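Since `lora_state_dict` is a classmethod, a checkpoint can be inspected without instantiating the pipeline. A minimal sketch; the LoRA repo id `my-org/mochi-lora` is a hypothetical placeholder:

```py
from diffusers import MochiPipeline

# `my-org/mochi-lora` is a hypothetical repo id used only for illustration.
state_dict = MochiPipeline.lora_state_dict("my-org/mochi-lora")
# Keys are transformer-prefixed LoRA parameters, e.g. "transformer.<...>.lora_A.weight".
print(sorted(state_dict)[:5])
```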
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs
+        are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
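For end users, `load_lora_weights` is the main entry point. A minimal usage sketch; `genmo/mochi-1-preview` is the base Mochi checkpoint, while the LoRA repo id and adapter name are hypothetical:

```py
import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Hypothetical LoRA repo; any checkpoint with transformer-prefixed LoRA keys works.
pipe.load_lora_weights("my-org/mochi-lora", adapter_name="my-style")
```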
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the transformer or prefixed with an additional `transformer` which can be used to distinguish
+                between text encoder lora layers.
+            transformer (`MochiTransformer3DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
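The classmethod can also be driven directly when a state dict is already in hand, for example one returned by `lora_state_dict`. A sketch reusing the hypothetical `pipe` and `state_dict` from the earlier snippets:

```py
# Load a previously fetched LoRA state dict straight into the transformer.
MochiPipeline.load_lora_into_transformer(
    state_dict,
    transformer=pipe.transformer,
    adapter_name="my-style",
    _pipeline=pipe,
)
```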
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
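Typical use is persisting adapter weights gathered from a fine-tuned transformer. A sketch assuming the transformer was wrapped with PEFT LoRA layers during training, so `get_peft_model_state_dict` can extract them:

```py
from peft.utils import get_peft_model_state_dict

# Assumes `pipe.transformer` carries trained PEFT LoRA layers.
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
MochiPipeline.save_lora_weights(
    save_directory="./mochi-lora",
    transformer_lora_layers=transformer_lora_layers,
    weight_name="pytorch_lora_weights.safetensors",
)
```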
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer", "text_encoder"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        super().unfuse_lora(components=components)
+
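Fusing folds the LoRA deltas into the base weights for faster inference; unfusing restores the originals. A round-trip sketch, continuing the hypothetical `pipe` from above:

```py
# Fuse at reduced strength, run inference, then restore the base weights.
pipe.fuse_lora(components=["transformer"], lora_scale=0.8)
# ... generate videos ...
pipe.unfuse_lora(components=["transformer"])
```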
+
+
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."