Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 29 additions & 20 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,29 +577,38 @@ def LinearStrengthModel(start, finish, size):
pipeline.set_ip_adapter_scale(ip_strengths)
```
"""
transformer = self.transformer
if not isinstance(scale, list):
scale = [[scale] * transformer.config.num_layers]
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
if len(scale) != transformer.config.num_layers:
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")

from ..pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type

scale_type = Union[int, float]
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
num_layers = self.transformer.config.num_layers

# Single value for all layers of all IP-Adapters
if isinstance(scale, scale_type):
scale = [scale for _ in range(num_ip_adapters)]
# List of per-layer scales for a single IP-Adapter
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
scale = [scale]
# Invalid scale type
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")

scale_configs = scale
if len(scale) != num_ip_adapters:
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")

key_id = 0
for attn_name, attn_processor in transformer.attn_processors.items():
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
f"{len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
attn_processor.scale[i] = scale_config[key_id]
key_id += 1
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
raise ValueError(
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
)

# Scalars are transformed to lists with length num_layers
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]

# Set scales. zip over scale_configs prevents going into single transformer layers
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
attn_processor.scale = scale

def unload_ip_adapter(self):
"""
Expand Down
15 changes: 8 additions & 7 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2778,9 +2778,8 @@ def __call__(

# IP-adapter
ip_query = hidden_states_query_proj
ip_attn_output = None
# for ip-adapter
# TODO: support for multiple adapters
ip_attn_output = torch.zeros_like(hidden_states)

for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
Expand All @@ -2791,12 +2790,14 @@ def __call__(
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_attn_output = F.scaled_dot_product_attention(
current_ip_hidden_states = F.scaled_dot_product_attention(
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_attn_output = scale * ip_attn_output
ip_attn_output = ip_attn_output.to(ip_query.dtype)
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states

return hidden_states, encoder_hidden_states, ip_attn_output
else:
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2583,6 +2583,11 @@ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[
super().__init__()
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)

@property
def num_ip_adapters(self) -> int:
"""Number of IP-Adapters loaded."""
return len(self.image_projection_layers)

def forward(self, image_embeds: List[torch.Tensor]):
projected_image_embeds = []

Expand Down
28 changes: 19 additions & 9 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,23 +405,28 @@ def prepare_ip_adapter_image_embeds(
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)

for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
):
for single_ip_adapter_image in ip_adapter_image:
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)

image_embeds.append(single_image_embeds[None, :])
else:
if not isinstance(ip_adapter_image_embeds, list):
ip_adapter_image_embeds = [ip_adapter_image_embeds]

if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
raise ValueError(
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)

for single_image_embeds in ip_adapter_image_embeds:
image_embeds.append(single_image_embeds)

ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
for single_image_embeds in image_embeds:
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
Expand Down Expand Up @@ -871,11 +876,16 @@ def __call__(
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
zeros_image = np.zeros((width, height, 3), dtype=np.uint8)
negative_ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :]
negative_ip_adapter_image_embeds = [negative_ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters

elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
zeros_image = np.zeros((width, height, 3), dtype=np.uint8)
ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :]
ip_adapter_image_embeds = [ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters

if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
Expand Down