@@ -3290,6 +3290,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
         super().unfuse_lora(components=components)
 
 
+class LTXVideoLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return the state dict for the LoRA weights.
+
+        <Tip warning={true}>
+
+        We support loading A1111-formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
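+        Example:
+
+        A minimal sketch; the LoRA repository id and weight file name below are illustrative placeholders, not real
+        checkpoints:
+
+        ```py
+        from diffusers import LTXPipeline
+
+        pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video")
+        # Hypothetical LoRA repository and file name; substitute a real checkpoint.
+        state_dict = pipe.lora_state_dict(
+            "your-username/ltx-video-lora", weight_name="pytorch_lora_weights.safetensors"
+        )
+        print(list(state_dict.keys())[:5])
+        ```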
3353+ """
3354+ # Load the main state dict first which has the LoRA layers for either of
3355+ # transformer and text encoder or both.
3356+ cache_dir = kwargs .pop ("cache_dir" , None )
3357+ force_download = kwargs .pop ("force_download" , False )
3358+ proxies = kwargs .pop ("proxies" , None )
3359+ local_files_only = kwargs .pop ("local_files_only" , None )
3360+ token = kwargs .pop ("token" , None )
3361+ revision = kwargs .pop ("revision" , None )
3362+ subfolder = kwargs .pop ("subfolder" , None )
3363+ weight_name = kwargs .pop ("weight_name" , None )
3364+ use_safetensors = kwargs .pop ("use_safetensors" , None )
3365+
3366+ allow_pickle = False
3367+ if use_safetensors is None :
3368+ use_safetensors = True
3369+ allow_pickle = True
3370+
3371+ user_agent = {
3372+ "file_type" : "attn_procs_weights" ,
3373+ "framework" : "pytorch" ,
3374+ }
3375+
3376+ state_dict = _fetch_state_dict (
3377+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
3378+ weight_name = weight_name ,
3379+ use_safetensors = use_safetensors ,
3380+ local_files_only = local_files_only ,
3381+ cache_dir = cache_dir ,
3382+ force_download = force_download ,
3383+ proxies = proxies ,
3384+ token = token ,
3385+ revision = revision ,
3386+ subfolder = subfolder ,
3387+ user_agent = user_agent ,
3388+ allow_pickle = allow_pickle ,
3389+ )
3390+
3391+ is_dora_scale_present = any ("dora_scale" in k for k in state_dict )
3392+ if is_dora_scale_present :
3393+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
3394+ logger .warning (warn_msg )
3395+ state_dict = {k : v for k , v in state_dict .items () if "dora_scale" not in k }
3396+
3397+ return state_dict
3398+
3399+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3400+ def load_lora_weights (
3401+ self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], adapter_name = None , ** kwargs
3402+ ):
3403+ """
3404+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3405+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3406+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3407+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3408+ dict is loaded into `self.transformer`.
3409+
3410+ Parameters:
3411+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3412+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3413+ adapter_name (`str`, *optional*):
3414+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3415+ `default_{i}` where i is the total number of adapters being loaded.
3416+ low_cpu_mem_usage (`bool`, *optional*):
3417+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3418+ weights.
3419+ kwargs (`dict`, *optional*):
3420+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
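+
+        Example:
+
+        A minimal end-to-end sketch; the LoRA repository id and adapter name are illustrative placeholders:
+
+        ```py
+        import torch
+        from diffusers import LTXPipeline
+
+        pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
+        # Hypothetical LoRA repository; substitute a real checkpoint.
+        pipe.load_lora_weights("your-username/ltx-video-lora", adapter_name="my-adapter")
+        video = pipe(prompt="A red panda eating bamboo").frames[0]
+        ```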
3421+ """
3422+ if not USE_PEFT_BACKEND :
3423+ raise ValueError ("PEFT backend is required for this method." )
3424+
3425+ low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT_LORA )
3426+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
3427+ raise ValueError (
3428+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3429+ )
3430+
3431+ # if a dict is passed, copy it instead of modifying it inplace
3432+ if isinstance (pretrained_model_name_or_path_or_dict , dict ):
3433+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict .copy ()
3434+
3435+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3436+ state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
3437+
3438+ is_correct_format = all ("lora" in key for key in state_dict .keys ())
3439+ if not is_correct_format :
3440+ raise ValueError ("Invalid LoRA checkpoint." )
3441+
3442+ self .load_lora_into_transformer (
3443+ state_dict ,
3444+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
3445+ adapter_name = adapter_name ,
3446+ _pipeline = self ,
3447+ low_cpu_mem_usage = low_cpu_mem_usage ,
3448+ )
3449+
3450+ @classmethod
3451+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
3452+ def load_lora_into_transformer (
3453+ cls , state_dict , transformer , adapter_name = None , _pipeline = None , low_cpu_mem_usage = False
3454+ ):
3455+ """
3456+ This will load the LoRA layers specified in `state_dict` into `transformer`.
3457+
3458+ Parameters:
3459+ state_dict (`dict`):
3460+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3461+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3462+ encoder lora layers.
3463+ transformer (`LTXVideoTransformer3DModel`):
3464+ The Transformer model to load the LoRA layers into.
3465+ adapter_name (`str`, *optional*):
3466+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3467+ `default_{i}` where i is the total number of adapters being loaded.
3468+ low_cpu_mem_usage (`bool`, *optional*):
3469+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3470+ weights.
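+
+        Example:
+
+        A minimal sketch; the LoRA repository id is an illustrative placeholder:
+
+        ```py
+        from diffusers import LTXPipeline
+
+        pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video")
+        # Hypothetical LoRA repository; `lora_state_dict` is documented above.
+        state_dict = pipe.lora_state_dict("your-username/ltx-video-lora")
+        pipe.load_lora_into_transformer(state_dict, transformer=pipe.transformer, adapter_name="my-adapter")
+        ```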
3471+ """
3472+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
3473+ raise ValueError (
3474+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3475+ )
3476+
3477+ # Load the layers corresponding to transformer.
3478+ logger .info (f"Loading { cls .transformer_name } ." )
3479+ transformer .load_lora_adapter (
3480+ state_dict ,
3481+ network_alphas = None ,
3482+ adapter_name = adapter_name ,
3483+ _pipeline = _pipeline ,
3484+ low_cpu_mem_usage = low_cpu_mem_usage ,
3485+ )
3486+
3487+ @classmethod
3488+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
3489+ def save_lora_weights (
3490+ cls ,
3491+ save_directory : Union [str , os .PathLike ],
3492+ transformer_lora_layers : Dict [str , Union [torch .nn .Module , torch .Tensor ]] = None ,
3493+ is_main_process : bool = True ,
3494+ weight_name : str = None ,
3495+ save_function : Callable = None ,
3496+ safe_serialization : bool = True ,
3497+ ):
3498+ r"""
3499+ Save the LoRA parameters corresponding to the UNet and text encoder.
3500+
3501+ Arguments:
3502+ save_directory (`str` or `os.PathLike`):
3503+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
3504+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
3505+ State dict of the LoRA layers corresponding to the `transformer`.
3506+ is_main_process (`bool`, *optional*, defaults to `True`):
3507+ Whether the process calling this is the main process or not. Useful during distributed training and you
3508+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
3509+ process to avoid race conditions.
3510+ save_function (`Callable`):
3511+ The function to use to save the state dictionary. Useful during distributed training when you need to
3512+ replace `torch.save` with another method. Can be configured with the environment variable
3513+ `DIFFUSERS_SAVE_MODE`.
3514+ safe_serialization (`bool`, *optional*, defaults to `True`):
3515+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
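+
+        Example:
+
+        A minimal sketch of the typical training-script usage; `transformer` is assumed to be a pre-existing
+        `LTXVideoTransformer3DModel` that already has PEFT LoRA layers attached:
+
+        ```py
+        from peft.utils import get_peft_model_state_dict
+
+        from diffusers import LTXPipeline
+
+        # `transformer` is an assumed, LoRA-equipped model (e.g. at the end of a training run).
+        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        LTXPipeline.save_lora_weights(
+            save_directory="./ltx-video-lora",
+            transformer_lora_layers=transformer_lora_layers,
+        )
+        ```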
3516+ """
3517+ state_dict = {}
3518+
3519+ if not transformer_lora_layers :
3520+ raise ValueError ("You must pass `transformer_lora_layers`." )
3521+
3522+ if transformer_lora_layers :
3523+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
3524+
3525+ # Save the model
3526+ cls .write_lora_layers (
3527+ state_dict = state_dict ,
3528+ save_directory = save_directory ,
3529+ is_main_process = is_main_process ,
3530+ weight_name = weight_name ,
3531+ save_function = save_function ,
3532+ safe_serialization = safe_serialization ,
3533+ )
3534+
3535+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
3536+ def fuse_lora (
3537+ self ,
3538+ components : List [str ] = ["transformer" , "text_encoder" ],
3539+ lora_scale : float = 1.0 ,
3540+ safe_fusing : bool = False ,
3541+ adapter_names : Optional [List [str ]] = None ,
3542+ ** kwargs ,
3543+ ):
3544+ r"""
3545+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3546+
3547+ <Tip warning={true}>
3548+
3549+ This is an experimental API.
3550+
3551+ </Tip>
3552+
3553+ Args:
3554+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3555+ lora_scale (`float`, defaults to 1.0):
3556+ Controls how much to influence the outputs with the LoRA parameters.
3557+ safe_fusing (`bool`, defaults to `False`):
3558+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
3559+ adapter_names (`List[str]`, *optional*):
3560+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3561+
3562+ Example:
3563+
3564+ ```py
3565+ from diffusers import DiffusionPipeline
3566+ import torch
3567+
3568+ pipeline = DiffusionPipeline.from_pretrained(
3569+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3570+ ).to("cuda")
3571+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3572+ pipeline.fuse_lora(lora_scale=0.7)
3573+ ```
3574+ """
3575+ super ().fuse_lora (
3576+ components = components , lora_scale = lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names
3577+ )
3578+
3579+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
3580+ def unfuse_lora (self , components : List [str ] = ["transformer" , "text_encoder" ], ** kwargs ):
3581+ r"""
3582+ Reverses the effect of
3583+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3584+
3585+ <Tip warning={true}>
3586+
3587+ This is an experimental API.
3588+
3589+ </Tip>
3590+
3591+ Args:
3592+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3593+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
3594+ unfuse_text_encoder (`bool`, defaults to `True`):
3595+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
3596+ LoRA parameters then it won't have any effect.
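+
+        Example:
+
+        Continuing the `fuse_lora` example above, this restores the original, unfused weights:
+
+        ```py
+        pipeline.unfuse_lora()
+        ```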
3597+ """
3598+ super ().unfuse_lora (components = components )
3599+
3600+
32933601class LoraLoaderMixin (StableDiffusionLoraLoaderMixin ):
32943602 def __init__ (self , * args , ** kwargs ):
32953603 deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."