
Commit a71cfce

Initial implementation of Flux multi IP-Adapter
1 parent: d75ea3c

File tree (6 files changed: 131 additions, 34 deletions)

examples/community/pipeline_flux_semantic_guidance.py
src/diffusers/loaders/ip_adapter.py
src/diffusers/models/attention_processor.py
src/diffusers/models/embeddings.py
src/diffusers/pipelines/flux/pipeline_flux.py
src/diffusers/pipelines/pipeline_loading_utils.py

examples/community/pipeline_flux_semantic_guidance.py

Lines changed: 2 additions & 2 deletions

@@ -537,9 +537,9 @@ def prepare_ip_adapter_image_embeds(
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]
 
-            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
                 raise ValueError(
-                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
                 )
 
             for single_ip_adapter_image, image_proj_layer in zip(
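
The check above simply requires one reference image per loaded IP-Adapter. A minimal sketch of that invariant; the pipeline object and images are assumed to already exist, and nothing here beyond `num_ip_adapters` comes from this commit:

# `pipe` is assumed to be a Flux pipeline with two IP-Adapters already loaded,
# and `style_image` / `face_image` are assumed to be PIL images.
images = [style_image, face_image]
assert len(images) == pipe.transformer.encoder_hid_proj.num_ip_adapters
output = pipe(prompt="a portrait photo", ip_adapter_image=images)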

src/diffusers/loaders/ip_adapter.py

Lines changed: 29 additions & 20 deletions

@@ -577,29 +577,38 @@ def LinearStrengthModel(start, finish, size):
             pipeline.set_ip_adapter_scale(ip_strengths)
             ```
         """
-        transformer = self.transformer
-        if not isinstance(scale, list):
-            scale = [[scale] * transformer.config.num_layers]
-        elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
-            if len(scale) != transformer.config.num_layers:
-                raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
+
+        from ..pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type
+
+        scale_type = Union[int, float]
+        num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
+        num_layers = self.transformer.config.num_layers
+
+        # Single value for all layers of all IP-Adapters
+        if isinstance(scale, scale_type):
+            scale = [scale for _ in range(num_ip_adapters)]
+        # List of per-layer scales for a single IP-Adapter
+        elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
             scale = [scale]
+        # Invalid scale type
+        elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
+            raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
 
-        scale_configs = scale
+        if len(scale) != num_ip_adapters:
+            raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
 
-        key_id = 0
-        for attn_name, attn_processor in transformer.attn_processors.items():
-            if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
-                if len(scale_configs) != len(attn_processor.scale):
-                    raise ValueError(
-                        f"Cannot assign {len(scale_configs)} scale_configs to "
-                        f"{len(attn_processor.scale)} IP-Adapter."
-                    )
-                elif len(scale_configs) == 1:
-                    scale_configs = scale_configs * len(attn_processor.scale)
-                for i, scale_config in enumerate(scale_configs):
-                    attn_processor.scale[i] = scale_config[key_id]
-                key_id += 1
+        if any(len(s) != num_layers for s in scale if isinstance(s, list)):
+            invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
+            raise ValueError(
+                f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
+            )
+
+        # Scalars are transformed to lists with length num_layers
+        scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
+
+        # Set scales. zip over scale_configs prevents going into single transformer layers
+        for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
+            attn_processor.scale = scale
 
     def unload_ip_adapter(self):
         """

src/diffusers/models/attention_processor.py

Lines changed: 8 additions & 7 deletions

@@ -2778,9 +2778,8 @@ def __call__(
 
             # IP-adapter
             ip_query = hidden_states_query_proj
-            ip_attn_output = None
-            # for ip-adapter
-            # TODO: support for multiple adapters
+            ip_attn_output = torch.zeros_like(hidden_states)
+
             for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
                 ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
             ):
@@ -2791,12 +2790,14 @@ def __call__(
                 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                 # the output of sdp = (batch, num_heads, seq_len, head_dim)
                 # TODO: add support for attn.scale when we move to Torch 2.1
-                ip_attn_output = F.scaled_dot_product_attention(
+                current_ip_hidden_states = F.scaled_dot_product_attention(
                     ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                 )
-                ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-                ip_attn_output = scale * ip_attn_output
-                ip_attn_output = ip_attn_output.to(ip_query.dtype)
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+                ip_attn_output += scale * current_ip_hidden_states
 
             return hidden_states, encoder_hidden_states, ip_attn_output
         else:
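
The processor now sums every adapter's attention output, each weighted by its own scale, instead of keeping only the last adapter's result. The arithmetic in isolation, with made-up shapes; this is a standalone sketch, not the processor itself:

import torch
import torch.nn.functional as F

batch_size, heads, seq_len, head_dim = 1, 2, 16, 8
ip_query = torch.randn(batch_size, heads, seq_len, head_dim)

# Two adapters, each contributing its own projected keys/values and its own scale.
adapter_kv = [
    (torch.randn(batch_size, heads, 4, head_dim), torch.randn(batch_size, heads, 4, head_dim))
    for _ in range(2)
]
scales = [0.6, 0.3]

# Accumulate weighted contributions, mirroring `ip_attn_output = torch.zeros_like(...)` above.
ip_attn_output = torch.zeros(batch_size, seq_len, heads * head_dim)
for (ip_key, ip_value), scale in zip(adapter_kv, scales):
    out = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
    out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    ip_attn_output += scale * out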

src/diffusers/models/embeddings.py

Lines changed: 5 additions & 0 deletions

@@ -2583,6 +2583,11 @@ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[
         super().__init__()
         self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
 
+    @property
+    def num_ip_adapters(self) -> int:
+        """Number of IP-Adapters loaded."""
+        return len(self.image_projection_layers)
+
     def forward(self, image_embeds: List[torch.Tensor]):
         projected_image_embeds = []
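
The new property only counts the wrapped projection layers. A quick illustration with placeholder modules standing in for real IP-Adapter projections (assumes the class is importable from `diffusers.models.embeddings`, as in this commit):

import torch.nn as nn

from diffusers.models.embeddings import MultiIPAdapterImageProjection

# Two identity modules stand in for real image-projection layers.
multi_proj = MultiIPAdapterImageProjection([nn.Identity(), nn.Identity()])
assert multi_proj.num_ip_adapters == 2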

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 13 additions & 4 deletions

@@ -405,9 +405,9 @@ def prepare_ip_adapter_image_embeds(
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]
 
-            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
                 raise ValueError(
-                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
                 )
 
             for single_ip_adapter_image, image_proj_layer in zip(
@@ -868,14 +868,23 @@ def __call__(
         else:
            guidance = None
 
+        # TODO: Clarify this section
         if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
             negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
         ):
-            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            negative_ip_adapter_image = (
+                [np.zeros((width, height, 3), dtype=np.uint8) for _ in range(len(ip_adapter_image))]
+                if isinstance(ip_adapter_image, list)
+                else np.zeros((width, height, 3), dtype=np.uint8)
+            )
         elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
             negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
         ):
-            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            ip_adapter_image = (
+                [np.zeros((width, height, 3), dtype=np.uint8) for _ in range(len(negative_ip_adapter_image))]
+                if isinstance(negative_ip_adapter_image, list)
+                else np.zeros((width, height, 3), dtype=np.uint8)
+            )
 
         if self.joint_attention_kwargs is None:
             self._joint_attention_kwargs = {}
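
Taken together, the pipeline-level changes let a Flux pipeline take one image per loaded IP-Adapter and fall back to per-adapter black placeholder images for whichever side (positive or negative) was not supplied. A hedged end-to-end sketch: the repository ids, weight names, and image encoder follow the public diffusers Flux IP-Adapter examples and are not part of this diff, the image URLs are placeholders, and multi-adapter loading is assumed to accept lists as it does for the SD pipelines:

import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Illustrative: load two IP-Adapter checkpoints (one list entry per adapter, assumed to be accepted).
pipe.load_ip_adapter(
    ["XLabs-AI/flux-ip-adapter", "XLabs-AI/flux-ip-adapter"],
    weight_name=["ip_adapter.safetensors", "ip_adapter.safetensors"],
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale([0.7, 0.4])

style = load_image("https://example.com/style.png")  # placeholder URL
face = load_image("https://example.com/face.png")    # placeholder URL

# One image per adapter; negative placeholders are created automatically by the block above.
image = pipe(
    prompt="a portrait in the reference style",
    ip_adapter_image=[style, face],
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]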

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 74 additions & 1 deletion

@@ -17,7 +17,7 @@
 import re
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
 
 import requests
 import torch
@@ -1059,3 +1059,76 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
                 break
     if has_transformers_component and not is_transformers_version(">", "4.47.1"):
         raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
+
+
+def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
+    """
+    Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
+    the correct type as well.
+    """
+    if not isinstance(class_or_tuple, tuple):
+        class_or_tuple = (class_or_tuple,)
+
+    # Unpack unions
+    unpacked_class_or_tuple = []
+    for t in class_or_tuple:
+        if get_origin(t) is Union:
+            unpacked_class_or_tuple.extend(get_args(t))
+        else:
+            unpacked_class_or_tuple.append(t)
+    class_or_tuple = tuple(unpacked_class_or_tuple)
+
+    if Any in class_or_tuple:
+        return True
+
+    obj_type = type(obj)
+    # Classes with obj's type
+    class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
+
+    # Singular types (e.g. int, ControlNet, ...)
+    # Untyped collections (e.g. List, but not List[int])
+    elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
+    if () in elem_class_or_tuple:
+        return True
+    # Typed lists or sets
+    elif obj_type in (list, set):
+        return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
+    # Typed tuples
+    elif obj_type is tuple:
+        return any(
+            # Tuples with any length and single type (e.g. Tuple[int, ...])
+            (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
+            or
+            # Tuples with fixed length and any types (e.g. Tuple[int, str])
+            (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
+            for t in elem_class_or_tuple
+        )
+    # Typed dicts
+    elif obj_type is dict:
+        return any(
+            all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
+            for kt, vt in elem_class_or_tuple
+        )
+
+    else:
+        return False
+
+
+def _get_detailed_type(obj: Any) -> Type:
+    """
+    Gets a detailed type for an object, including nested types for collections.
+    """
+    obj_type = type(obj)
+
+    if obj_type in (list, set):
+        obj_origin_type = List if obj_type is list else Set
+        elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
+        return obj_origin_type[elems_type]
+    elif obj_type is tuple:
+        return Tuple[tuple(_get_detailed_type(x) for x in obj)]
+    elif obj_type is dict:
+        keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
+        values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
+        return Dict[keys_type, values_type]
+    else:
+        return obj_type
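
A few examples of what the two helpers accept and report, matching how `set_ip_adapter_scale` uses them; they are private helpers, imported here only for illustration:

from typing import Dict, List, Tuple, Union

from diffusers.pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type

# Scalars and flat lists.
assert _is_valid_type(0.5, Union[int, float])
assert _is_valid_type([0.1, 0.2], List[Union[int, float]])

# The nested form accepted by set_ip_adapter_scale: scalars and/or per-layer lists.
assert _is_valid_type([0.5, [0.1, 0.2]], List[Union[float, List[float]]])
assert not _is_valid_type(["a", 0.2], List[Union[int, float]])

# _get_detailed_type reconstructs a nested typing annotation from a concrete value.
assert _get_detailed_type([1, "a"]) == List[Union[int, str]]
assert _get_detailed_type({"k": (1, 2.0)}) == Dict[str, Tuple[int, float]]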
