
Commit ecc1c18

initial commit to add HiDreamImageLoraLoaderMixin

1 parent aa6b6e2 commit ecc1c18

3 files changed, with 338 additions and 1 deletion

src/diffusers/loaders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ def text_encoder_attn_modules(text_encoder):
             "SanaLoraLoaderMixin",
             "Lumina2LoraLoaderMixin",
             "WanLoraLoaderMixin",
+            "HiDreamImageLoraLoaderMixin"
         ]
         _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
         _import_structure["ip_adapter"] = [

src/diffusers/loaders/lora_pipeline.py

Lines changed: 335 additions & 0 deletions
@@ -5395,6 +5395,341 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         """
         super().unfuse_lora(components=components, **kwargs)

+class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        # conversion.
+        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
+        if non_diffusers:
+            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
+
+        return state_dict
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`HiDreamImageTransformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ...  # load diffusers pipeline
+                max_rank = ...  # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)

 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
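
For reference, a minimal usage sketch of the mixin once it is mixed into `HiDreamImagePipeline` (see the pipeline change below); the LoRA repository id, weight file name, adapter name, and scale are illustrative placeholders, not values taken from this commit:

```py
# Hypothetical usage sketch: assumes a HiDreamImagePipeline has already been
# constructed; the LoRA repo id, weight name, and adapter name are placeholders.
from diffusers import HiDreamImagePipeline

pipe: HiDreamImagePipeline = ...  # load the pipeline as usual

# Provided by HiDreamImageLoraLoaderMixin (added in this commit)
pipe.load_lora_weights(
    "some-user/some-hidream-lora",  # placeholder repository id
    weight_name="pytorch_lora_weights.safetensors",  # placeholder file name
    adapter_name="style",
)

# Optionally merge the LoRA weights into the transformer for faster inference,
# then undo the merge when the adapter is no longer needed.
pipe.fuse_lora(components=["transformer"], lora_scale=0.8)
pipe.unfuse_lora(components=["transformer"])
```

Note that `fuse_lora` and `unfuse_lora` simply delegate to `LoraBaseMixin` via `super()`, as the diff above shows.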

src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
 )

 from ...image_processor import VaeImageProcessor
+from ...loaders import HiDreamImageLoraLoaderMixin
 from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
 from ...utils import is_torch_xla_available, logging
@@ -151,7 +152,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class HiDreamImagePipeline(DiffusionPipeline):
+class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds"]

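Finally, a training-side sketch of the save path; the PEFT-based extraction of the transformer's LoRA parameters and the output directory are assumptions for illustration, not part of this commit:

```py
# Hypothetical sketch: persisting transformer LoRA layers with the new
# save_lora_weights classmethod, now reachable through HiDreamImagePipeline.
from peft.utils import get_peft_model_state_dict

from diffusers import HiDreamImagePipeline

transformer = ...  # assumed: a HiDreamImageTransformer2DModel carrying trained LoRA adapters

# Collect only the LoRA parameters of the transformer as a flat state dict.
transformer_lora_layers = get_peft_model_state_dict(transformer)

HiDreamImagePipeline.save_lora_weights(
    save_directory="./hidream-lora-checkpoint",  # placeholder path
    transformer_lora_layers=transformer_lora_layers,
    safe_serialization=True,
)

# The saved folder can later be passed back to pipe.load_lora_weights(...).
```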