
Commit b8f7fe6

handle ip adapter params correctly
1 parent ff21b7f commit b8f7fe6

4 files changed (+29, −156 lines)

src/diffusers/loaders/ip_adapter.py

Lines changed: 5 additions & 4 deletions
````diff
@@ -40,8 +40,6 @@
 from ..models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
-    FluxAttnProcessor2_0,
-    FluxIPAdapterJointAttnProcessor2_0,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
     IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ def unload_ip_adapter(self):
         >>> ...
         ```
         """
+        # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+        from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
         # remove CLIP image encoder
         if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
             self.image_encoder = None
@@ -886,9 +887,9 @@ def unload_ip_adapter(self):
         # restore original Transformer attention processors layers
         attn_procs = {}
         for name, value in self.transformer.attn_processors.items():
-            attn_processor_class = FluxAttnProcessor2_0()
+            attn_processor_class = FluxAttnProcessor()
             attn_procs[name] = (
-                attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+                attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
             )
         self.transformer.set_attn_processor(attn_procs)
 
````
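
For context, a minimal usage sketch of the code path this hunk touches: loading a Flux IP-Adapter and unloading it, which now restores plain `FluxAttnProcessor` instances. This is not part of the commit; the model/adapter repo ids and `weight_name` below are illustrative assumptions, and depending on the checkpoint you may also need to pass an image encoder.

```python
# Sketch only (not from this commit). Repo ids and weight_name are assumptions.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",          # hypothetical adapter repo for illustration
    weight_name="ip_adapter.safetensors",  # hypothetical weight file name
)

# ... run inference with an ip_adapter_image ...

# unload_ip_adapter() now swaps every FluxIPAdapterAttnProcessor back to a plain FluxAttnProcessor
pipe.unload_ip_adapter()
```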

src/diffusers/loaders/transformer_flux.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -87,9 +87,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         return image_projection
 
     def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-        from ..models.attention_processor import (
-            FluxIPAdapterJointAttnProcessor2_0,
-        )
+        from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
 
         if low_cpu_mem_usage:
             if is_accelerate_available():
@@ -121,7 +119,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
         else:
             cross_attention_dim = self.config.joint_attention_dim
             hidden_size = self.inner_dim
-        attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+        attn_processor_class = FluxIPAdapterAttnProcessor
         num_image_text_embeds = []
         for state_dict in state_dicts:
             if "proj.weight" in state_dict["image_proj"]:
```
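
Continuing the sketch above, one way to confirm which processor class the loader installed (the new `FluxIPAdapterAttnProcessor` rather than the removed `FluxIPAdapterJointAttnProcessor2_0`). `pipe` is the hypothetical pipeline from the previous snippet.

```python
# Sketch, reusing the hypothetical `pipe` from the previous snippet.
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

ip_layers = {
    name: proc
    for name, proc in pipe.transformer.attn_processors.items()
    if isinstance(proc, FluxIPAdapterAttnProcessor)
}
print(f"{len(ip_layers)} attention layers carry IP-Adapter projections")
```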

src/diffusers/models/attention_processor.py

Lines changed: 10 additions & 146 deletions
```diff
@@ -2501,152 +2501,6 @@ def __call__(
         return hidden_states
 
 
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_query = hidden_states_query_proj
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -6019,6 +5873,16 @@ def __new__(cls, *args, **kwargs):
         return FluxAttnProcessor(*args, **kwargs)
 
 
+class FluxIPAdapterJointAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+        return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
```
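
The processor implementation itself now lives in `transformer_flux.py`; what remains here is a thin shim whose `__new__` emits a deprecation warning and returns an instance of the new class. Below is a self-contained sketch of that pattern, not diffusers code; `OldProcessor` and `NewProcessor` are made-up names for illustration.

```python
# Self-contained sketch of the deprecation pattern used above: constructing the old
# class name warns and hands back an instance of the new class instead.
import warnings


class NewProcessor:
    def __init__(self, scale=1.0):
        self.scale = scale


class OldProcessor:  # deprecated alias
    def __new__(cls, *args, **kwargs):
        warnings.warn(
            "`OldProcessor` is deprecated; use `NewProcessor` instead.",
            FutureWarning,
        )
        # returning a non-OldProcessor instance means OldProcessor.__init__ is never run
        return NewProcessor(*args, **kwargs)


proc = OldProcessor(scale=0.5)
assert isinstance(proc, NewProcessor)  # existing callers keep working against the new class
```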

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import inspect
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -241,7 +241,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = torch.nn.functional(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -354,6 +356,14 @@ def forward(
         image_rotary_emb: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
         return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
 
 
```
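
The new `forward` filters `**kwargs` against the processor's `__call__` signature, warning about unexpected keys, so stale IP-Adapter kwargs no longer reach processors that do not accept them. A standalone sketch of the same technique; `call_processor` and `DummyProcessor` are hypothetical names for illustration.

```python
# Standalone sketch of the kwargs-filtering technique added in forward():
# keep only keyword arguments the processor's __call__ actually accepts.
import inspect


def call_processor(processor, *args, **kwargs):
    accepted = set(inspect.signature(processor.__call__).parameters.keys())
    quiet = {"ip_adapter_masks", "ip_hidden_states"}  # dropped silently when unsupported
    unused = [k for k in kwargs if k not in accepted and k not in quiet]
    if unused:
        print(f"warning: {unused} are not expected by {type(processor).__name__} and will be ignored")
    return processor(*args, **{k: v for k, v in kwargs.items() if k in accepted})


class DummyProcessor:
    def __call__(self, hidden_states, scale=1.0):
        return hidden_states * scale


print(call_processor(DummyProcessor(), 2.0, scale=3.0, ip_adapter_masks=None, unknown_flag=True))
# warns about ['unknown_flag'], silently drops ip_adapter_masks, prints 6.0
```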