
Commit de8909a

Added new SD3IPAdapterMixin loader
1 parent 0ef36dd commit de8909a

4 files changed: +291 additions, -195 deletions

src/diffusers/loaders/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -71,7 +71,10 @@ def text_encoder_attn_modules(text_encoder):
             "Mochi1LoraLoaderMixin",
         ]
         _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
-        _import_structure["ip_adapter"] = ["IPAdapterMixin"]
+        _import_structure["ip_adapter"] = [
+            "IPAdapterMixin",
+            "SD3IPAdapterMixin",
+        ]

     _import_structure["peft"] = ["PeftAdapterMixin"]

@@ -83,7 +86,10 @@ def text_encoder_attn_modules(text_encoder):
         from .utils import AttnProcsLayers

         if is_transformers_available():
-            from .ip_adapter import IPAdapterMixin
+            from .ip_adapter import (
+                IPAdapterMixin,
+                SD3IPAdapterMixin,
+            )
             from .lora_pipeline import (
                 AmusedLoraLoaderMixin,
                 CogVideoXLoraLoaderMixin,
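With both the lazy-import table and the `TYPE_CHECKING` branch updated, the new mixin is exposed next to the existing one. A minimal import check, assuming a diffusers installation that contains this commit and has `torch` and `transformers` available:

```python
# Minimal sanity check of the updated lazy imports; assumes a diffusers
# installation that includes this commit, with torch and transformers present.
from diffusers.loaders import IPAdapterMixin, SD3IPAdapterMixin

print(IPAdapterMixin.__name__, SD3IPAdapterMixin.__name__)
```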

src/diffusers/loaders/ip_adapter.py

Lines changed: 224 additions & 8 deletions
@@ -33,16 +33,23 @@


 if is_transformers_available():
-    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
-
-    from ..models.attention_processor import (
-        AttnProcessor,
-        AttnProcessor2_0,
-        IPAdapterAttnProcessor,
-        IPAdapterAttnProcessor2_0,
-        IPAdapterXFormersAttnProcessor,
+    from transformers import (
+        CLIPImageProcessor,
+        CLIPVisionModelWithProjection,
+        SiglipImageProcessor,
+        SiglipVisionModel
     )

+    from ..models.attention_processor import (
+        AttnProcessor,
+        AttnProcessor2_0,
+        JointAttnProcessor2_0,
+        IPAdapterAttnProcessor,
+        IPAdapterAttnProcessor2_0,
+        IPAdapterXFormersAttnProcessor,
+        IPAdapterJointAttnProcessor2_0,
+    )
+
 logger = logging.get_logger(__name__)

@@ -348,3 +355,212 @@ def unload_ip_adapter(self):
                 else value.__class__()
             )
         self.unet.set_attn_processor(attn_procs)
+
+
+class SD3IPAdapterMixin:
+    """Mixin for handling StableDiffusion 3 IP Adapters."""
+
+    @validate_hf_hub_args
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        subfolder: str,
+        weight_name: str,
+        image_encoder_folder: Optional[str] = "image_encoder",
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+            subfolder (`str`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
+            weight_name (`str`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `subfolder`.
+            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
+                The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+                `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+                `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+                `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+                `image_encoder_folder="different_subfolder/image_encoder"`.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+        """
+        # Load the main state dict first
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+            model_file = _get_model_file(
+                pretrained_model_name_or_path_or_dict,
+                weights_name=weight_name,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+            if weight_name.endswith(".safetensors"):
+                state_dict = {"image_proj": {}, "ip_adapter": {}}
+                with safe_open(model_file, framework="pt", device="cpu") as f:
+                    for key in f.keys():
+                        if key.startswith("image_proj."):
+                            state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                        elif key.startswith("ip_adapter."):
+                            state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+            else:
+                state_dict = load_state_dict(model_file)
+        else:
+            state_dict = pretrained_model_name_or_path_or_dict
+
+        keys = list(state_dict.keys())
+        if "image_proj" not in keys and "ip_adapter" not in keys:
+            raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+        # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+            if image_encoder_folder is not None:
+                if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                    logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+                    if image_encoder_folder.count("/") == 0:
+                        image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+                    else:
+                        image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+                    # Common args for loading image encoder and image processor
+                    args = dict(
+                        pretrained_model_name_or_path_or_dict,
+                        subfolder=image_encoder_subfolder,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        cache_dir=cache_dir,
+                        local_files_only=local_files_only,
+                    )
+
+                    self.register_modules(
+                        feature_extractor=SiglipImageProcessor.from_pretrained(**args).to(self.device, dtype=self.dtype),
+                        image_encoder=SiglipVisionModel.from_pretrained(**args).to(self.device, dtype=self.dtype),
+                    )
+                else:
+                    raise ValueError(
+                        "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+                    )
+            else:
+                logger.warning(
+                    "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+                    " Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
+                )
+
+        # Load IP-Adapter into transformer
+        self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
+
+    def set_ip_adapter_scale(self, scale: float):
+        """
+        Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image
+        prompt, and a value of 0.0 means it is only conditioned on the text prompt. Lowering this value encourages
+        the model to produce more diverse images, but they may not be as aligned with the image prompt.
+
+        Example:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.set_ip_adapter_scale(0.6)
+        >>> ...
+        ```
+        """
+        for attn_processor in self.transformer.attn_processors.values():
+            if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0):
+                attn_processor.scale = scale
+
+    def unload_ip_adapter(self):
+        """
+        Unloads the IP Adapter weights.
+
+        Example:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.unload_ip_adapter()
+        >>> ...
+        ```
+        """
+        # Remove image encoder
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+            self.image_encoder = None
+            self.register_to_config(image_encoder=None)
+
+        # Remove feature extractor
+        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+            self.feature_extractor = None
+            self.register_to_config(feature_extractor=None)
+
+        # Remove image projection
+        self.transformer.image_proj = None
+
+        # Restore original attention processor layers
+        attn_procs = {
+            name: (
+                JointAttnProcessor2_0()
+                if isinstance(value, IPAdapterJointAttnProcessor2_0)
+                else value.__class__()
+            )
+            for name, value in self.transformer.attn_processors.items()
+        }
+        self.transformer.set_attn_processor(attn_procs)
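For orientation, the mixin is meant to be inherited by an SD3 pipeline class; the pipeline-side wiring is not part of the files shown in this commit. A rough usage sketch, where the checkpoint repository, subfolder, and weight file name are placeholders rather than values confirmed here:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Assumes a diffusers build in which StableDiffusion3Pipeline inherits SD3IPAdapterMixin.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Placeholder repo and file names; the checkpoint must contain `image_proj`
# and `ip_adapter` state-dict entries plus a SigLIP image encoder folder.
pipe.load_ip_adapter(
    "<org>/<sd3-ip-adapter-repo>",
    subfolder="",
    weight_name="ip-adapter.safetensors",
    image_encoder_folder="image_encoder",
)
pipe.set_ip_adapter_scale(0.6)
```

With the adapter loaded, `ip_adapter_image` can be passed on the pipeline call, and `unload_ip_adapter()` restores the original attention processors.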

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 56 additions & 2 deletions
@@ -27,12 +27,13 @@
     AttentionProcessor,
     FusedJointAttnProcessor2_0,
     JointAttnProcessor2_0,
+    IPAdapterJointAttnProcessor2_0,
 )
-from ...models.modeling_utils import ModelMixin
+from ...models.modeling_utils import ModelMixin, load_model_dict_into_meta
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed, TimePerceiverResampler
 from ..modeling_outputs import Transformer2DModelOutput

@@ -332,6 +333,59 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

+    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool):
+        # IP-Adapter cross attention parameters
+        hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
+        ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
+        timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
+
+        # Dict where key is transformer layer index, value is attention processor's state dict
+        # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
+        layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
+        for key, weights in state_dict["ip_adapter"].items():
+            idx, name = key.split(".", maxsplit=1)
+            layer_state_dict[int(idx)][name] = weights
+
+        # Create IP-Adapter attention processor
+        attn_procs = {}
+        for idx, name in enumerate(self.attn_processors.keys()):
+            attn_procs[name] = IPAdapterJointAttnProcessor2_0(
+                hidden_size=hidden_size,
+                ip_hidden_states_dim=ip_hidden_states_dim,
+                head_dim=self.config.attention_head_dim,
+                timesteps_emb_dim=timesteps_emb_dim,
+            ).to(self.device, dtype=self.dtype)
+
+            if not low_cpu_mem_usage:
+                attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
+            else:
+                load_model_dict_into_meta(attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype)
+
+        self.set_attn_processor(attn_procs)
+
+        # Image projection parameters
+        embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
+        output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
+        hidden_dim = state_dict["image_proj"]["latents"].shape[2]
+        heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
+        num_queries = state_dict["image_proj"]["latents"].shape[1]
+        timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
+
+        # Image projection
+        self.image_proj = TimePerceiverResampler(
+            embed_dim=embed_dim,
+            output_dim=output_dim,
+            hidden_dim=hidden_dim,
+            heads=heads,
+            num_queries=num_queries,
+            timestep_in_dim=timestep_in_dim,
+        ).to(device=self.device, dtype=self.dtype)
+
+        if not low_cpu_mem_usage:
+            self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
+        else:
+            load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
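To make the state-dict handling above concrete: the `ip_adapter` sub-dict is flat, with keys of the form `"<layer_idx>.<param_name>"`, and `_load_ip_adapter_weights` regroups it per attention processor before loading. A small self-contained sketch of that regrouping step, with fabricated tensor shapes used purely for illustration:

```python
import torch

# Fabricated flat IP-Adapter state dict; shapes are illustrative only.
flat_ip_adapter = {
    "0.norm_ip.linear.weight": torch.zeros(8, 4),
    "0.to_k_ip.weight": torch.zeros(8, 8),
    "1.norm_ip.linear.weight": torch.zeros(8, 4),
    "1.to_k_ip.weight": torch.zeros(8, 8),
}

num_layers = 2  # stands in for len(self.attn_processors)
layer_state_dict = {idx: {} for idx in range(num_layers)}
for key, weights in flat_ip_adapter.items():
    idx, name = key.split(".", maxsplit=1)
    layer_state_dict[int(idx)][name] = weights

# Each per-layer dict can now be loaded into its own attention processor.
print(sorted(layer_state_dict[0]))  # ['norm_ip.linear.weight', 'to_k_ip.weight']
```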
