
Commit c7a0b9e

Merge branch 'main' into 8bit-lora-loading
2 parents: 55b4137 + 2432f80

File tree: 8 files changed (+205, -12 lines)


src/diffusers/loaders/lora_pipeline.py (36 additions, 3 deletions)

@@ -21,6 +21,7 @@
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -1981,10 +1982,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]

+                # The model may be loaded with different quantization schemes, which can flatten the params.
+                # `bitsandbytes`, for example, flattens the weights when using 4-bit; 8-bit bnb models
+                # preserve the weight shape.
+                module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
                 # This means there's no need for an expansion in the params, so we simply skip.
-                if tuple(module_weight.shape) == (out_features, in_features):
+                if tuple(module_weight_shape) == (out_features, in_features):
                     continue

+                # TODO (sayakpaul): We still need to consider if the module we're expanding is
+                # quantized and handle it accordingly if that is the case.
                 module_out_features, module_in_features = module_weight.shape
                 debug_message = ""
                 if in_features > module_in_features:
@@ -2080,13 +2088,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

-            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
@@ -2098,6 +2109,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):

         return lora_state_dict

+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+

 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
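
Why the new `_calculate_module_shape` helper consults `quant_state.shape`: with `bitsandbytes`, 4-bit parameters are stored packed and flattened, so `weight.shape` no longer reflects the logical `(out_features, in_features)` of the Linear layer, whereas 8-bit parameters keep their original shape. A minimal sketch of the same check outside the mixin (pure PyTorch; the 4-bit behaviour is only described in comments, since running it would require `bitsandbytes` and a CUDA device):

import torch

def logical_weight_shape(weight: torch.Tensor) -> torch.Size:
    # Mirrors the diff's inner `_get_weight_shape` helper: bnb 4-bit params
    # ("Params4bit") are packed, so the logical shape is recorded on the
    # quant state instead of on the tensor itself.
    if weight.__class__.__name__ == "Params4bit":
        return weight.quant_state.shape
    return weight.shape

# For a regular (or bnb 8-bit) linear layer, both notions coincide:
linear = torch.nn.Linear(in_features=32, out_features=64, bias=False)
print(logical_weight_shape(linear.weight))  # torch.Size([64, 32])

# For a bnb 4-bit layer, `weight.shape` would instead be a packed shape
# (roughly (out_features * in_features / 2, 1)), which is why the expansion
# check above compares against the quant-state shape rather than the raw shape.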

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py (130 additions, 8 deletions)

@@ -13,19 +13,21 @@
 # limitations under the License.

 import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -162,7 +164,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -194,10 +196,14 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         tokenizer_3 (`T5TokenizerFast`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor for IP Adapter.
     """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

     def __init__(
@@ -211,6 +217,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()

@@ -224,6 +232,8 @@ def __init__(
             tokenizer_3=tokenizer_3,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
@@ -818,6 +828,10 @@ def clip_skip(self):
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1

+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
@@ -826,6 +840,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device: (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device: (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
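
The `enable_sequential_cpu_offload` override above only warns about the `image_encoder`; applying the exclusion is left to the caller. A minimal usage sketch of the workaround the warning describes (the checkpoint id is illustrative, not part of this diff):

import torch
from diffusers import StableDiffusion3InpaintPipeline

# Illustrative checkpoint id; any SD3 checkpoint would be handled the same way.
pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)

# As suggested by the warning: if the image encoder uses torch.nn.MultiheadAttention,
# sequential CPU offload can fail for it, so exclude it before enabling offloading.
pipe._exclude_from_cpu_offload.append("image_encoder")
pipe.enable_sequential_cpu_offload()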
@@ -853,8 +945,11 @@ def __call__(
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -890,9 +985,9 @@ def __call__(
             mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
                 `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
                 latents tensor will be generated by `mask_image`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
             padding_mask_crop (`int`, *optional*, defaults to `None`):
                 The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
@@ -953,12 +1048,22 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
                 a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that is called at the end of each denoising step during inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -1006,6 +1111,7 @@ def __call__(

         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

         # 2. Define call parameters
@@ -1160,7 +1266,22 @@ def __call__(
                 f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}."
             )

-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1181,6 +1302,7 @@ def __call__(
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]

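Taken together, these changes route IP-Adapter image conditioning through `joint_attention_kwargs` in the inpainting pipeline. A rough usage sketch, assuming an SD3 checkpoint and an SD3 IP-Adapter repository are available (the model ids, scale, and image URLs below are illustrative and not confirmed by this diff):

import torch
from diffusers import StableDiffusion3InpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # illustrative checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

# SD3IPAdapterMixin provides load_ip_adapter/set_ip_adapter_scale; the repo id is illustrative.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
pipe.set_ip_adapter_scale(0.6)

image = load_image("https://example.com/input.png")        # image to inpaint (placeholder URL)
mask = load_image("https://example.com/mask.png")           # white = region to repaint
ip_image = load_image("https://example.com/reference.png")  # reference image for IP-Adapter

result = pipe(
    prompt="a red sports car parked on the street",
    image=image,
    mask_image=mask,
    ip_adapter_image=ip_image,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
result.save("inpainted.png")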

src/diffusers/utils/__init__.py (1 addition, 1 deletion)

@@ -101,7 +101,7 @@
     is_xformers_available,
     requires_backends,
 )
-from .loading_utils import get_module_from_name, load_image, load_video
+from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (

src/diffusers/utils/loading_utils.py (12 additions)

@@ -148,3 +148,15 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
         module = new_module
         tensor_name = splits[-1]
     return module, tensor_name
+
+
+def get_submodule_by_name(root_module, module_path: str):
+    current = root_module
+    parts = module_path.split(".")
+    for part in parts:
+        if part.isdigit():
+            idx = int(part)
+            current = current[idx]  # e.g., for nn.ModuleList or nn.Sequential
+        else:
+            current = getattr(current, part)
+    return current
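
A small illustration of how `get_submodule_by_name` resolves a dotted path, including numeric indices into container modules; this mirrors how `_calculate_module_shape` strips a trailing `.weight` from a state-dict key and walks the remaining module path (the toy model below is purely for demonstration):

import torch
from diffusers.utils import get_submodule_by_name  # re-exported in utils/__init__.py above

# Toy model: a ModuleList of blocks, each with a Linear named `to_q`.
class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([Block(), Block()])

model = ToyModel()

# "transformer_blocks" -> getattr, "1" -> index into the ModuleList, "to_q" -> getattr.
submodule = get_submodule_by_name(model, "transformer_blocks.1.to_q")
print(submodule.weight.shape)  # torch.Size([8, 8])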

tests/models/transformers/test_models_transformer_cogvideox.py (1 addition)

@@ -33,6 +33,7 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = CogVideoXTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
+    model_split_percents = [0.7, 0.7, 0.8]

     @property
     def dummy_input(self):

tests/models/transformers/test_models_transformer_cogview3plus.py (1 addition)

@@ -33,6 +33,7 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = CogView3PlusTransformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
+    model_split_percents = [0.7, 0.6, 0.6]

     @property
     def dummy_input(self):

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py (2 additions)

@@ -106,6 +106,8 @@ def get_dummy_components(self):
             "tokenizer_3": tokenizer_3,
             "transformer": transformer,
             "vae": vae,
+            "image_encoder": None,
+            "feature_extractor": None,
         }

     def get_dummy_inputs(self, device, seed=0):
