Commit 9fb8e0d

Fixes

1 parent 5af8c9a commit 9fb8e0d

File tree

4 files changed: +205 −69 lines changed

src/diffusers/loaders/ip_adapter.py

Lines changed: 16 additions & 3 deletions

@@ -357,10 +357,10 @@ class FluxIPAdapterMixin:
     def load_ip_adapter(
         self,
         pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
-        subfolder: Union[str, List[str]],
         weight_name: Union[str, List[str]],
+        subfolder: Optional[Union[str, List[str]]] = "",
         image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
-        image_encoder_subfolder: Optional[str] = None,
+        image_encoder_subfolder: Optional[str] = "",
         **kwargs,
     ):
         """
@@ -492,6 +492,7 @@ def load_ip_adapter(
                                 ".".join(key.split(".")[1:])
                                 .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
                                 .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
+                                .replace("processor.", "")
                             )
                             state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
             else:
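For orientation, a hedged usage sketch of the reordered signature above, assuming the Flux pipeline exposes this mixin; the repository ids and weight file name are placeholders, not values taken from this commit. Note that weight_name now comes before the optional subfolder argument.

import torch
from diffusers import FluxPipeline

# Placeholder checkpoint ids and weight file name; substitute real ones.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_ip_adapter(
    "<ip-adapter-repo-id>",                # placeholder repository id
    weight_name="ip_adapter.safetensors",  # placeholder weight file, now the second positional argument
    image_encoder_pretrained_model_name_or_path="<image-encoder-repo-id>",  # placeholder
)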
@@ -555,10 +556,22 @@ def set_ip_adapter_scale(self, scale):
         ```
         """
         transformer = self.transformer
+        if not isinstance(scale, list):
+            scale = [scale]
+
+        scale_configs = scale

         for attn_name, attn_processor in transformer.attn_processors.items():
             if isinstance(attn_processor, (FluxIPAdapterAttnProcessor2_0)):
-                attn_processor.scale = scale
+                if len(scale_configs) != len(attn_processor.scale):
+                    raise ValueError(
+                        f"Cannot assign {len(scale_configs)} scale_configs to "
+                        f"{len(attn_processor.scale)} IP-Adapter."
+                    )
+                elif len(scale_configs) == 1:
+                    scale_configs = scale_configs * len(attn_processor.scale)
+                for i, scale_config in enumerate(scale_configs):
+                    attn_processor.scale[i] = scale_config

     def unload_ip_adapter(self):
         """

src/diffusers/loaders/transformer_flux.py

Lines changed: 38 additions & 28 deletions

@@ -84,7 +84,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us

         return image_projection

-    def _convert_ip_adapter_attn_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
         from ..models.attention_processor import (
             FluxIPAdapterAttnProcessor2_0,
         )
@@ -110,35 +110,47 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dict, low_cpu_mem_usage=Fa

         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
-        key_id = 1
+        key_id = 0
         init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
         for name in self.attn_processors.keys():
             if name.startswith("single_transformer_blocks"):
-                continue
-
-            cross_attention_dim = self.config.joint_attention_dim
-            hidden_size = self.config.inner_dim
-            attn_processor_class = FluxIPAdapterAttnProcessor2_0
-
-            with init_context():
-                attn_procs[name] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                )
-
-            value_dict = {}
-            value_dict.update({"to_k_ip.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
-            value_dict.update({"to_v_ip.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
-
-            if not low_cpu_mem_usage:
-                attn_procs[name].load_state_dict(value_dict)
+                attn_processor_class = self.attn_processors[name].__class__
+                attn_procs[name] = attn_processor_class()
             else:
-                device = next(iter(value_dict.values())).device
-                dtype = next(iter(value_dict.values())).dtype
-                load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
-
-            key_id += 1
+                cross_attention_dim = self.config.joint_attention_dim
+                hidden_size = self.inner_dim
+                attn_processor_class = FluxIPAdapterAttnProcessor2_0
+                num_image_text_embeds = []
+                for state_dict in state_dicts:
+                    if "proj.weight" in state_dict["image_proj"]:
+                        # IP-Adapter
+                        num_image_text_embeds += [4]
+
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                        dtype=self.dtype,
+                        device=self.device,
+                    )
+
+                value_dict = {}
+                for i, state_dict in enumerate(state_dicts):
+                    value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+                    value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+                    value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
+                    value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
+
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = self.device
+                    dtype = self.dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
+                key_id += 1

         return attn_procs

@@ -160,5 +172,3 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):

         self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
         self.config.encoder_hid_dim_type = "ip_image_proj"
-
-        self.to(dtype=self.dtype, device=self.device)
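To make the new key layout concrete, an illustrative sketch (not code from this commit) of the value_dict assembled for one double-stream block when two adapter state dicts are converted; adapter_a and adapter_b are hypothetical inputs shaped like the loader's per-checkpoint state dicts.

# Hypothetical inputs: one loaded checkpoint state dict per IP-Adapter.
state_dicts = [adapter_a, adapter_b]
key_id = 0  # index of the current double-stream block

value_dict = {}
for i, state_dict in enumerate(state_dicts):
    # Adapter i fills the i-th nn.Linear of the processor's to_k_ip / to_v_ip ModuleLists.
    value_dict[f"to_k_ip.{i}.weight"] = state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]
    value_dict[f"to_v_ip.{i}.weight"] = state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]
    value_dict[f"to_k_ip.{i}.bias"] = state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]
    value_dict[f"to_v_ip.{i}.bias"] = state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]
# value_dict now matches the state_dict keys of that block's FluxIPAdapterAttnProcessor2_0,
# so it can be loaded with load_state_dict (or load_model_dict_into_meta for meta tensors).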

src/diffusers/models/attention_processor.py

Lines changed: 133 additions & 33 deletions

@@ -482,7 +482,7 @@ def forward(
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty

         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"ip_adapter_masks"}
+        quiet_attn_parameters = {"ip_adapter_masks", "image_projection"}
         unused_kwargs = [
             k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
         ]
@@ -1987,31 +1987,43 @@ def __call__(
         return hidden_states


-class FluxIPAdapterAttnProcessor2_0:
+class FluxIPAdapterAttnProcessor2_0(torch.nn.Module):
     """Flux Attention processor for IP-Adapter."""

-    def __init__(self, hidden_size: int, cross_attention_dim: int, scale: float = 1.0):
+    def __init__(
+        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+    ):
         super().__init__()

-        r"""
-        Args:
-            hidden_size (`int`):
-                The hidden size of the attention layer.
-            cross_attention_dim (`int`):
-                The number of channels in the `encoder_hidden_states`.
-            scale (`float`, defaults to 1.0):
-                the weight scale of image prompt.
-        """
-
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
-                "FluxIPAdapterAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
             )

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
         self.scale = scale

-        self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size)
-        self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size)
+        self.to_k_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )

     def __call__(
         self,
@@ -2020,24 +2032,27 @@ def __call__(
         encoder_hidden_states: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        image_projection: Optional[torch.Tensor] = None,
+        image_projection: Optional[List[torch.Tensor]] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
+        if image_projection is None:
+            raise ValueError("image_projection is None")
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
-        query = attn.to_q(hidden_states)
+        hidden_states_query_proj = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         if attn.norm_q is not None:
-            query = attn.norm_q(query)
+            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
         if attn.norm_k is not None:
             key = attn.norm_k(key)

@@ -2064,7 +2079,7 @@ def __call__(
             encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

             # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

@@ -2091,19 +2106,104 @@ def __call__(
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

             # IP-adapter
-            ip_key = self.to_k_ip(image_projection)
-            ip_value = self.to_v_ip(image_projection)
+            ip_hidden_states = image_projection
+
+            if ip_adapter_masks is not None:
+                if not isinstance(ip_adapter_masks, List):
+                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                    raise ValueError(
+                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                        f"({len(ip_hidden_states)})"
+                    )
+                else:
+                    for index, (mask, scale, ip_state) in enumerate(
+                        zip(ip_adapter_masks, self.scale, ip_hidden_states)
+                    ):
+                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                            raise ValueError(
+                                "Each element of the ip_adapter_masks array should be a tensor with shape "
+                                "[1, num_images_for_ip_adapter, height, width]."
+                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                            )
+                        if mask.shape[1] != ip_state.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                            )
+                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of scales ({len(scale)}) at index {index}"
+                            )
+            else:
+                ip_adapter_masks = [None] * len(self.scale)
+
+            ip_query = hidden_states_query_proj
+            # for ip-adapter
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                skip = False
+                if isinstance(scale, list):
+                    if all(s == 0 for s in scale):
+                        skip = True
+                elif scale == 0:
+                    skip = True
+                if not skip:
+                    if mask is not None:
+                        if not isinstance(scale, list):
+                            scale = [scale] * mask.shape[1]
+
+                        current_num_images = mask.shape[1]
+                        for i in range(current_num_images):
+                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                            # TODO: add support for attn.scale when we move to Torch 2.1
+                            _current_ip_hidden_states = F.scaled_dot_product_attention(
+                                ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                            )

+                            _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                                batch_size, -1, attn.heads * head_dim
+                            )
+                            _current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype)
+
+                            mask_downsample = IPAdapterMaskProcessor.downsample(
+                                mask[:, i, :, :],
+                                batch_size,
+                                _current_ip_hidden_states.shape[1],
+                                _current_ip_hidden_states.shape[2],
+                            )
+
+                            mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device)
+                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                    else:
+                        ip_key = to_k_ip(current_ip_hidden_states)
+                        ip_value = to_v_ip(current_ip_hidden_states)

-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-            ip_hidden_states = ip_hidden_states.to(query.dtype)
-            hidden_states = hidden_states + self.scale * ip_hidden_states
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        current_ip_hidden_states = F.scaled_dot_product_attention(
+                            ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
+
+                        current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
+                        current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+
+                        hidden_states = hidden_states + scale * current_ip_hidden_states

             return hidden_states, encoder_hidden_states
         else:
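For reference, a minimal sketch of constructing the reworked processor directly. The hidden_size and cross_attention_dim values are assumptions matching Flux.1 (inner_dim 3072, joint_attention_dim 4096); in practice they come from the transformer config, as in the loader change above.

import torch
from diffusers.models.attention_processor import FluxIPAdapterAttnProcessor2_0

proc = FluxIPAdapterAttnProcessor2_0(
    hidden_size=3072,           # assumed Flux inner_dim
    cross_attention_dim=4096,   # assumed Flux joint_attention_dim
    num_tokens=[4, 4],          # two adapters, four image tokens each
    scale=[1.0, 0.5],           # independent per-adapter scales
    dtype=torch.bfloat16,
)

# One K/V projection pair per adapter, stored as ModuleLists.
assert len(proc.to_k_ip) == len(proc.to_v_ip) == 2
# set_ip_adapter_scale writes per index into this list at runtime.
proc.scale[0] = 0.8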
