@@ -5395,6 +5395,341 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
53955395 """
53965396 super ().unfuse_lora (components = components , ** kwargs )
53975397
class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
5413+ r"""
5414+ Return state dict for lora weights and the network alphas.
5415+
5416+ <Tip warning={true}>
5417+
5418+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
5419+
5420+ This function is experimental and might change in the future.
5421+
5422+ </Tip>
5423+
5424+ Parameters:
5425+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
5426+ Can be either:
5427+
5428+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
5429+ the Hub.
5430+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
5431+ with [`ModelMixin.save_pretrained`].
5432+ - A [torch state
5433+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
5434+
5435+ cache_dir (`Union[str, os.PathLike]`, *optional*):
5436+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
5437+ is not used.
5438+ force_download (`bool`, *optional*, defaults to `False`):
5439+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
5440+ cached versions if they exist.
5441+
5442+ proxies (`Dict[str, str]`, *optional*):
5443+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
5444+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
5445+ local_files_only (`bool`, *optional*, defaults to `False`):
5446+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
5447+ won't be downloaded from the Hub.
5448+ token (`str` or *bool*, *optional*):
5449+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
5450+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
5451+ revision (`str`, *optional*, defaults to `"main"`):
5452+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
5453+ allowed by Git.
5454+ subfolder (`str`, *optional*, defaults to `""`):
5455+ The subfolder location of a model file within a larger model repository on the Hub or locally.
5456+
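        Example (illustrative sketch only; the repository id and weight file name below are placeholders, not a real
        checkpoint):

        ```py
        from diffusers import HiDreamImagePipeline

        # Fetch (and, if needed, convert) the LoRA state dict without loading it into a pipeline.
        state_dict = HiDreamImagePipeline.lora_state_dict(
            "some-user/some-hidream-lora",  # placeholder repository id
            weight_name="pytorch_lora_weights.safetensors",  # placeholder file name
        )
        print(list(state_dict.keys())[:5])
        ```
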
5457+ """
5458+ # Load the main state dict first which has the LoRA layers for either of
5459+ # transformer and text encoder or both.
5460+ cache_dir = kwargs .pop ("cache_dir" , None )
5461+ force_download = kwargs .pop ("force_download" , False )
5462+ proxies = kwargs .pop ("proxies" , None )
5463+ local_files_only = kwargs .pop ("local_files_only" , None )
5464+ token = kwargs .pop ("token" , None )
5465+ revision = kwargs .pop ("revision" , None )
5466+ subfolder = kwargs .pop ("subfolder" , None )
5467+ weight_name = kwargs .pop ("weight_name" , None )
5468+ use_safetensors = kwargs .pop ("use_safetensors" , None )
5469+
5470+ allow_pickle = False
5471+ if use_safetensors is None :
5472+ use_safetensors = True
5473+ allow_pickle = True
5474+
5475+ user_agent = {
5476+ "file_type" : "attn_procs_weights" ,
5477+ "framework" : "pytorch" ,
5478+ }
5479+
5480+ state_dict = _fetch_state_dict (
5481+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
5482+ weight_name = weight_name ,
5483+ use_safetensors = use_safetensors ,
5484+ local_files_only = local_files_only ,
5485+ cache_dir = cache_dir ,
5486+ force_download = force_download ,
5487+ proxies = proxies ,
5488+ token = token ,
5489+ revision = revision ,
5490+ subfolder = subfolder ,
5491+ user_agent = user_agent ,
5492+ allow_pickle = allow_pickle ,
5493+ )
5494+
        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # conversion.
        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
        if non_diffusers:
            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)

        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
        dict is loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
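
        Example (minimal sketch; the LoRA repository id below is a placeholder used purely for illustration):

        ```py
        import torch
        from diffusers import HiDreamImagePipeline

        # Loading the base checkpoint may require supplying extra components (e.g. the Llama text encoder);
        # see the HiDream pipeline documentation for the full setup.
        pipe = HiDreamImagePipeline.from_pretrained("HiDream-ai/HiDream-I1-Full", torch_dtype=torch.bfloat16)
        pipe.load_lora_weights("some-user/some-hidream-lora", adapter_name="my_lora")  # placeholder repo id
        ```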
5530+ """
5531+ if not USE_PEFT_BACKEND :
5532+ raise ValueError ("PEFT backend is required for this method." )
5533+
5534+ low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT_LORA )
5535+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
5536+ raise ValueError (
5537+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
5538+ )
5539+
5540+ # if a dict is passed, copy it instead of modifying it inplace
5541+ if isinstance (pretrained_model_name_or_path_or_dict , dict ):
5542+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict .copy ()
5543+
5544+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
5545+ state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
5546+
5547+ is_correct_format = all ("lora" in key for key in state_dict .keys ())
5548+ if not is_correct_format :
5549+ raise ValueError ("Invalid LoRA checkpoint." )
5550+
5551+ self .load_lora_into_transformer (
5552+ state_dict ,
5553+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
5554+ adapter_name = adapter_name ,
5555+ _pipeline = self ,
5556+ low_cpu_mem_usage = low_cpu_mem_usage ,
5557+ )
5558+
    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
    ):
5564+ """
5565+ This will load the LoRA layers specified in `state_dict` into `transformer`.
5566+
5567+ Parameters:
5568+ state_dict (`dict`):
5569+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
5570+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
5571+ encoder lora layers.
5572+ transformer (`HiDreamImageTransformer2DModel`):
5573+ The Transformer model to load the LoRA layers into.
5574+ adapter_name (`str`, *optional*):
5575+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
5576+ `default_{i}` where i is the total number of adapters being loaded.
5577+ low_cpu_mem_usage (`bool`, *optional*):
5578+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
5579+ weights.
5580+ hotswap : (`bool`, *optional*)
5581+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
5582+ in-place. This means that, instead of loading an additional adapter, this will take the existing
5583+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
5584+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
5585+ torch.compile, loading the new adapter does not require recompilation of the model. When using
5586+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
5587+
5588+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
5589+ to call an additional method before loading the adapter:
5590+
5591+ ```py
5592+ pipeline = ... # load diffusers pipeline
5593+ max_rank = ... # the highest rank among all LoRAs that you want to load
5594+ # call *before* compiling and loading the LoRA adapter
5595+ pipeline.enable_lora_hotswap(target_rank=max_rank)
5596+ pipeline.load_lora_weights(file_name)
5597+ # optionally compile the model now
5598+ ```
5599+
5600+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
5601+ limitations to this technique, which are documented here:
5602+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
5603+ """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
5631+ r"""
5632+ Save the LoRA parameters corresponding to the UNet and text encoder.
5633+
        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
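
        Example (rough sketch; assumes `pipe` is a `HiDreamImagePipeline` whose transformer already carries
        PEFT-injected LoRA layers, e.g. after a LoRA training run):

        ```py
        from peft.utils import get_peft_model_state_dict

        from diffusers import HiDreamImagePipeline

        # Collect only the LoRA weights from the transformer and write them to disk.
        transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
        HiDreamImagePipeline.save_lora_weights(
            save_directory="my-hidream-lora",
            transformer_lora_layers=transformer_lora_layers,
        )
        ```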
5649+ """
5650+ state_dict = {}
5651+
5652+ if not transformer_lora_layers :
5653+ raise ValueError ("You must pass `transformer_lora_layers`." )
5654+
5655+ if transformer_lora_layers :
5656+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
5657+
5658+ # Save the model
5659+ cls .write_lora_layers (
5660+ state_dict = state_dict ,
5661+ save_directory = save_directory ,
5662+ is_main_process = is_main_process ,
5663+ weight_name = weight_name ,
5664+ save_function = save_function ,
5665+ safe_serialization = safe_serialization ,
5666+ )
5667+
    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
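
        Example (continuing the `fuse_lora` sketch above):

        ```py
        pipeline.fuse_lora(lora_scale=0.7)
        # ... run inference with the fused LoRA weights ...
        pipeline.unfuse_lora()  # restore the original, unfused parameters
        ```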
5731+ """
5732+ super ().unfuse_lora (components = components , ** kwargs )
53985733
53995734class LoraLoaderMixin (StableDiffusionLoraLoaderMixin ):
54005735 def __init__ (self , * args , ** kwargs ):