@@ -4115,6 +4115,311 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         super().unfuse_lora(components=components)


+class WanLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
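+
+    Example (a minimal sketch; the LoRA repository id and weight file name below are hypothetical
+    placeholders, substitute them with a real Wan LoRA checkpoint):
+
+    ```py
+    import torch
+    from diffusers import WanPipeline
+
+    pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
+    # hypothetical repo id and weight name -- replace with an actual Wan LoRA checkpoint
+    pipe.load_lora_weights("your-username/wan-lora", weight_name="pytorch_lora_weights.safetensors")
+    ```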
4121+ """
4122+
4123+ _lora_loadable_modules = ["transformer" ]
4124+ transformer_name = TRANSFORMER_NAME
4125+
4126+ @classmethod
4127+ @validate_hf_hub_args
4128+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
4129+ def lora_state_dict (
4130+ cls ,
4131+ pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]],
4132+ ** kwargs ,
4133+ ):
4134+ r"""
4135+ Return state dict for lora weights and the network alphas.
4136+
4137+ <Tip warning={true}>
4138+
4139+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
4140+
4141+ This function is experimental and might change in the future.
4142+
4143+ </Tip>
4144+
4145+ Parameters:
4146+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4147+ Can be either:
4148+
4149+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
4150+ the Hub.
4151+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
4152+ with [`ModelMixin.save_pretrained`].
4153+ - A [torch state
4154+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
4155+
4156+ cache_dir (`Union[str, os.PathLike]`, *optional*):
4157+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
4158+ is not used.
4159+ force_download (`bool`, *optional*, defaults to `False`):
4160+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
4161+ cached versions if they exist.
4162+
4163+ proxies (`Dict[str, str]`, *optional*):
4164+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
4165+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
4166+ local_files_only (`bool`, *optional*, defaults to `False`):
4167+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
4168+ won't be downloaded from the Hub.
4169+ token (`str` or *bool*, *optional*):
4170+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
4171+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
4172+ revision (`str`, *optional*, defaults to `"main"`):
4173+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
4174+ allowed by Git.
4175+ subfolder (`str`, *optional*, defaults to `""`):
4176+ The subfolder location of a model file within a larger model repository on the Hub or locally.
4177+
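+        Example (an illustrative sketch; the repository id below is a hypothetical placeholder):
+
+        ```py
+        from diffusers import WanPipeline
+
+        # inspect a LoRA checkpoint's keys without loading it into a model
+        state_dict = WanPipeline.lora_state_dict("your-username/wan-lora")
+        print(list(state_dict.keys())[:5])
+        ```
+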
4178+ """
4179+ # Load the main state dict first which has the LoRA layers for either of
4180+ # transformer and text encoder or both.
4181+ cache_dir = kwargs .pop ("cache_dir" , None )
4182+ force_download = kwargs .pop ("force_download" , False )
4183+ proxies = kwargs .pop ("proxies" , None )
4184+ local_files_only = kwargs .pop ("local_files_only" , None )
4185+ token = kwargs .pop ("token" , None )
4186+ revision = kwargs .pop ("revision" , None )
4187+ subfolder = kwargs .pop ("subfolder" , None )
4188+ weight_name = kwargs .pop ("weight_name" , None )
4189+ use_safetensors = kwargs .pop ("use_safetensors" , None )
4190+
4191+ allow_pickle = False
4192+ if use_safetensors is None :
4193+ use_safetensors = True
4194+ allow_pickle = True
4195+
4196+ user_agent = {
4197+ "file_type" : "attn_procs_weights" ,
4198+ "framework" : "pytorch" ,
4199+ }
4200+
4201+ state_dict = _fetch_state_dict (
4202+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
4203+ weight_name = weight_name ,
4204+ use_safetensors = use_safetensors ,
4205+ local_files_only = local_files_only ,
4206+ cache_dir = cache_dir ,
4207+ force_download = force_download ,
4208+ proxies = proxies ,
4209+ token = token ,
4210+ revision = revision ,
4211+ subfolder = subfolder ,
4212+ user_agent = user_agent ,
4213+ allow_pickle = allow_pickle ,
4214+ )
4215+
4216+ is_dora_scale_present = any ("dora_scale" in k for k in state_dict )
4217+ if is_dora_scale_present :
4218+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
4219+ logger .warning (warn_msg )
4220+ state_dict = {k : v for k , v in state_dict .items () if "dora_scale" not in k }
4221+
4222+ return state_dict
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs
+        are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
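+
+        Example (an illustrative sketch; assumes `pipe` is a pipeline using this mixin, and the
+        repository ids and adapter names are hypothetical placeholders):
+
+        ```py
+        # load two LoRAs under distinct adapter names so they can be switched or combined later
+        pipe.load_lora_weights("your-username/wan-style-lora", adapter_name="style")
+        pipe.load_lora_weights("your-username/wan-motion-lora", adapter_name="motion")
+        pipe.set_adapters(["style", "motion"], adapter_weights=[1.0, 0.8])
+        ```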
4246+ """
4247+ if not USE_PEFT_BACKEND :
4248+ raise ValueError ("PEFT backend is required for this method." )
4249+
4250+ low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT_LORA )
4251+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
4252+ raise ValueError (
4253+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4254+ )
4255+
4256+ # if a dict is passed, copy it instead of modifying it inplace
4257+ if isinstance (pretrained_model_name_or_path_or_dict , dict ):
4258+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict .copy ()
4259+
4260+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4261+ state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
4262+
4263+ is_correct_format = all ("lora" in key for key in state_dict .keys ())
4264+ if not is_correct_format :
4265+ raise ValueError ("Invalid LoRA checkpoint." )
4266+
4267+ self .load_lora_into_transformer (
4268+ state_dict ,
4269+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
4270+ adapter_name = adapter_name ,
4271+ _pipeline = self ,
4272+ low_cpu_mem_usage = low_cpu_mem_usage ,
4273+ )
4274+
4275+ @classmethod
4276+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
4277+ def load_lora_into_transformer (
4278+ cls , state_dict , transformer , adapter_name = None , _pipeline = None , low_cpu_mem_usage = False
4279+ ):
4280+ """
4281+ This will load the LoRA layers specified in `state_dict` into `transformer`.
4282+
4283+ Parameters:
4284+ state_dict (`dict`):
+                A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
+                into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+                them from text encoder LoRA layers.
+            transformer (`WanTransformer3DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
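+
+        Example (an illustrative sketch; assumes `pipe` is a pipeline using this mixin, and the
+        repository id is a hypothetical placeholder):
+
+        ```py
+        # load a raw LoRA state dict straight into the transformer component
+        state_dict = pipe.lora_state_dict("your-username/wan-lora")
+        pipe.load_lora_into_transformer(state_dict, transformer=pipe.transformer, adapter_name="wan-lora")
+        ```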
4296+ """
4297+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
4298+ raise ValueError (
4299+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4300+ )
4301+
4302+ # Load the layers corresponding to transformer.
4303+ logger .info (f"Loading { cls .transformer_name } ." )
4304+ transformer .load_lora_adapter (
4305+ state_dict ,
4306+ network_alphas = None ,
4307+ adapter_name = adapter_name ,
4308+ _pipeline = _pipeline ,
4309+ low_cpu_mem_usage = low_cpu_mem_usage ,
4310+ )
4311+
4312+ @classmethod
4313+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4314+ def save_lora_weights (
4315+ cls ,
4316+ save_directory : Union [str , os .PathLike ],
4317+ transformer_lora_layers : Dict [str , Union [torch .nn .Module , torch .Tensor ]] = None ,
4318+ is_main_process : bool = True ,
4319+ weight_name : str = None ,
4320+ save_function : Callable = None ,
4321+ safe_serialization : bool = True ,
4322+ ):
4323+ r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
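+
+        Example (a minimal sketch; assumes `pipe.transformer` already carries trained PEFT LoRA
+        layers, and the output directory is arbitrary):
+
+        ```py
+        from peft.utils import get_peft_model_state_dict
+        from diffusers import WanPipeline
+
+        # extract the LoRA-only weights from a fine-tuned transformer and save them to disk
+        transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
+        WanPipeline.save_lora_weights("./wan-lora", transformer_lora_layers=transformer_lora_layers)
+        ```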
4341+ """
4342+ state_dict = {}
4343+
4344+ if not transformer_lora_layers :
4345+ raise ValueError ("You must pass `transformer_lora_layers`." )
4346+
4347+ if transformer_lora_layers :
4348+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
4349+
4350+ # Save the model
4351+ cls .write_lora_layers (
4352+ state_dict = state_dict ,
4353+ save_directory = save_directory ,
4354+ is_main_process = is_main_process ,
4355+ weight_name = weight_name ,
4356+ save_function = save_function ,
4357+ safe_serialization = safe_serialization ,
4358+ )
4359+
4360+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
4361+ def fuse_lora (
4362+ self ,
4363+ components : List [str ] = ["transformer" ],
4364+ lora_scale : float = 1.0 ,
4365+ safe_fusing : bool = False ,
4366+ adapter_names : Optional [List [str ]] = None ,
4367+ ** kwargs ,
4368+ ):
4369+ r"""
4370+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
4371+
4372+ <Tip warning={true}>
4373+
4374+ This is an experimental API.
4375+
4376+ </Tip>
4377+
4378+ Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
4419+ """
4420+ super ().unfuse_lora (components = components )
4421+
4422+
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."