
Commit 18b2679

Merge branch 'main' into sd3-img2img-ipadapter
2 parents cd4eaf5 + c944f06

File tree: 13 files changed, 261 additions, 24 deletions

src/diffusers/loaders/lora_pipeline.py

Lines changed: 36 additions & 3 deletions
@@ -21,6 +21,7 @@
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -1981,10 +1982,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]
 
+                # Model maybe loaded with different quantization schemes which may flatten the params.
+                # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
+                # preserve weight shape.
+                module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
                 # This means there's no need for an expansion in the params, so we simply skip.
-                if tuple(module_weight.shape) == (out_features, in_features):
+                if tuple(module_weight_shape) == (out_features, in_features):
                     continue
 
+                # TODO (sayakpaul): We still need to consider if the module we're expanding is
+                # quantized and handle it accordingly if that is the case.
                 module_out_features, module_in_features = module_weight.shape
                 debug_message = ""
                 if in_features > module_in_features:
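Context for the new shape check: when the transformer is loaded with bitsandbytes 4-bit quantization, the stored weight is packed and flattened, so its raw `.shape` no longer equals the logical `(out_features, in_features)` and the old comparison would spuriously report a mismatch. A minimal illustration with made-up sizes (no bitsandbytes required; the packed shape below only approximates bnb's actual layout):

import torch

out_features, in_features = 3072, 64
logical_shape = torch.Size([out_features, in_features])

# Roughly how a 4-bit packed parameter is stored: two weights per byte, flattened
# into a column vector. The exact packing is a bitsandbytes implementation detail.
packed_shape = torch.Size([out_features * in_features // 2, 1])

print(tuple(packed_shape) == (out_features, in_features))   # False -> expansion would be wrongly attempted
print(tuple(logical_shape) == (out_features, in_features))  # True  -> expansion correctly skipped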
@@ -2080,13 +2088,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
 
-            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
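The zero-padded expansion above is lossless: the new columns of the LoRA A matrix are zero, so the extra input dimensions of the widened base layer contribute nothing to the LoRA update. A small worked example with illustrative sizes (not taken from the diff):

import torch

rank, old_in, new_in = 4, 64, 128
lora_A = torch.randn(rank, old_in)

# Mirror of the expansion: zeros of the wider shape, original weights copied in.
expanded = torch.zeros(rank, new_in)
expanded[:, :old_in].copy_(lora_A)

# The extra input features only hit zero columns, so the LoRA activation is unchanged.
x = torch.randn(2, new_in)
assert torch.allclose(x[:, :old_in] @ lora_A.T, x @ expanded.T)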
@@ -2098,6 +2109,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
 
         return lora_state_dict
 
+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
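The helper dispatches on the parameter's class name: for a bitsandbytes `Params4bit` the logical shape is read from `weight.quant_state.shape`, otherwise from `weight.shape`. A hedged sketch of that dispatch using a hypothetical test double (the `Params4bit` class below is a stand-in, not the real bitsandbytes type), so it runs without bitsandbytes installed:

import torch

def _get_weight_shape(weight):
    # Same check as in `_calculate_module_shape` above.
    return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape

# Regular nn.Linear weight: the shape is reported directly.
linear = torch.nn.Linear(64, 3072, bias=False)
print(_get_weight_shape(linear.weight))  # torch.Size([3072, 64])

# Hypothetical stand-in mimicking a bnb 4-bit parameter: packed storage, with the
# logical shape preserved on `quant_state`.
class _QuantState:
    def __init__(self, shape):
        self.shape = torch.Size(shape)

class Params4bit:
    def __init__(self, packed, logical_shape):
        self.data = packed
        self.shape = packed.shape              # flattened storage shape, e.g. (98304, 1)
        self.quant_state = _QuantState(logical_shape)

fake = Params4bit(torch.zeros(3072 * 64 // 2, 1, dtype=torch.uint8), (3072, 64))
print(_get_weight_shape(fake))  # torch.Size([3072, 64]), not the packed shape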

src/diffusers/loaders/textual_inversion.py

Lines changed: 4 additions & 4 deletions
@@ -40,7 +40,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
     force_download = kwargs.pop("force_download", False)
     proxies = kwargs.pop("proxies", None)
     local_files_only = kwargs.pop("local_files_only", None)
-    token = kwargs.pop("token", None)
+    hf_token = kwargs.pop("hf_token", None)
     revision = kwargs.pop("revision", None)
     subfolder = kwargs.pop("subfolder", None)
     weight_name = kwargs.pop("weight_name", None)
@@ -73,7 +73,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
                 force_download=force_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
-                token=token,
+                token=hf_token,
                 revision=revision,
                 subfolder=subfolder,
                 user_agent=user_agent,
@@ -93,7 +93,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
                 force_download=force_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
-                token=token,
+                token=hf_token,
                 revision=revision,
                 subfolder=subfolder,
                 user_agent=user_agent,
@@ -312,7 +312,7 @@ def load_textual_inversion(
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
-            token (`str` or *bool*, *optional*):
+            hf_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
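With this rename, callers pass the authentication token as `hf_token`, which is forwarded to the Hub download as `token` (see above). A hedged usage sketch; it assumes the public `load_textual_inversion` entry point forwards the renamed keyword, the checkpoint and embedding ids are the usual documentation examples, and the token string is a placeholder:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# `hf_token` replaces the old `token` keyword for authenticating against the Hub.
pipe.load_textual_inversion("sd-concepts-library/cat-toy", hf_token="hf_...")  # placeholder token

image = pipe("A <cat-toy> backpack", num_inference_steps=30).images[0]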

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 4 additions & 4 deletions
@@ -21,7 +21,7 @@
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import Mochi1LoraLoaderMixin
-from ...models.autoencoders import AutoencoderKL
+from ...models.autoencoders import AutoencoderKLMochi
 from ...models.transformers import MochiTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -151,8 +151,8 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
            Conditional Transformer architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
-        vae ([`AutoencoderKL`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        vae ([`AutoencoderKLMochi`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
@@ -171,7 +171,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
     def __init__(
         self,
         scheduler: FlowMatchEulerDiscreteScheduler,
-        vae: AutoencoderKL,
+        vae: AutoencoderKLMochi,
         text_encoder: T5EncoderModel,
         tokenizer: T5TokenizerFast,
         transformer: MochiTransformer3DModel,
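The annotation and docstring now name the class the pipeline actually ships with: the Mochi-specific video VAE. A quick hedged check; the checkpoint id "genmo/mochi-1-preview" is the commonly used Mochi repository and is an assumption here, not part of this diff:

import torch
from diffusers import AutoencoderKLMochi, MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)

# The loaded VAE is the Mochi-specific video autoencoder, matching the corrected
# type hint `vae: AutoencoderKLMochi` above.
assert isinstance(pipe.vae, AutoencoderKLMochi)
print(type(pipe.vae).__name__)  # "AutoencoderKLMochi"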

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 14 additions & 3 deletions
@@ -16,6 +16,7 @@
 import inspect
 import re
 import urllib.parse as ul
+import warnings
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -41,6 +42,7 @@
     ASPECT_RATIO_1024_BIN,
 )
 from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
+from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN
 from .pag_utils import PAGMixin
 
 
@@ -639,7 +641,7 @@ def __call__(
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        clean_caption: bool = True,
+        clean_caption: bool = False,
         use_resolution_binning: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -755,7 +757,9 @@ def __call__(
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
         if use_resolution_binning:
-            if self.transformer.config.sample_size == 64:
+            if self.transformer.config.sample_size == 128:
+                aspect_ratio_bin = ASPECT_RATIO_4096_BIN
+            elif self.transformer.config.sample_size == 64:
                 aspect_ratio_bin = ASPECT_RATIO_2048_BIN
             elif self.transformer.config.sample_size == 32:
                 aspect_ratio_bin = ASPECT_RATIO_1024_BIN
@@ -912,7 +916,14 @@ def __call__(
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            try:
+                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            except torch.cuda.OutOfMemoryError as e:
+                warnings.warn(
+                    f"{e}. \n"
+                    f"Try to use VAE tiling for large images. For example: \n"
+                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+                )
         if use_resolution_binning:
             image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
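The new `sample_size == 128` branch routes 4K-capable checkpoints to the 4096-pixel aspect-ratio bins. Since Sana's autoencoder downsamples by 32x, a transformer `sample_size` of 128/64/32 corresponds to 4096/2048/1024 px outputs; that relationship is background context rather than part of this diff, and the 1024-bin import path below is assumed from the PixArt-alpha pipeline. A standalone sketch of the selection:

from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from diffusers.pipelines.sana.pipeline_sana import ASPECT_RATIO_4096_BIN

def pick_aspect_ratio_bin(sample_size: int) -> dict:
    # Mirrors the branching added above, pulled out of the pipeline for illustration.
    if sample_size == 128:
        return ASPECT_RATIO_4096_BIN
    if sample_size == 64:
        return ASPECT_RATIO_2048_BIN
    if sample_size == 32:
        return ASPECT_RATIO_1024_BIN
    raise ValueError(f"Invalid sample size: {sample_size}")

# A 4K checkpoint (sample_size 128) now snaps requests to 4096-px resolution buckets.
print(list(pick_aspect_ratio_bin(128).items())[:3])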

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 9 additions & 1 deletion
@@ -16,6 +16,7 @@
 import inspect
 import re
 import urllib.parse as ul
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -953,7 +954,14 @@ def __call__(
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            try:
+                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            except torch.cuda.OutOfMemoryError as e:
+                warnings.warn(
+                    f"{e}. \n"
+                    f"Try to use VAE tiling for large images. For example: \n"
+                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+                )
         if use_resolution_binning:
             image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
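Both SanaPipeline and SanaPAGPipeline now warn instead of crashing when the final VAE decode runs out of CUDA memory, and the warning points at VAE tiling as the mitigation. A hedged usage sketch of that mitigation; the checkpoint id is an illustrative assumption, while the `enable_tiling(...)` call is taken verbatim from the warning text above:

import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Decode in 512-px tiles so the single large VAE decode at the end of __call__
# is far less likely to raise torch.cuda.OutOfMemoryError in the first place.
pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)

image = pipe(prompt="a cyberpunk cat with a neon sign that says 'Sana'").images[0]
image.save("sana.png")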
