
Commit 60d2673

Fixes
1 parent 6a6636c commit 60d2673

4 files changed: +205 -69 lines changed

src/diffusers/loaders/ip_adapter.py

Lines changed: 16 additions & 3 deletions
@@ -357,10 +357,10 @@ class FluxIPAdapterMixin:
     def load_ip_adapter(
         self,
         pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
-        subfolder: Union[str, List[str]],
         weight_name: Union[str, List[str]],
+        subfolder: Optional[Union[str, List[str]]] = "",
         image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
-        image_encoder_subfolder: Optional[str] = None,
+        image_encoder_subfolder: Optional[str] = "",
         **kwargs,
     ):
         """
@@ -492,6 +492,7 @@ def load_ip_adapter(
                             ".".join(key.split(".")[1:])
                             .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
                             .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
+                            .replace("processor.", "")
                         )
                         state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
             else:
@@ -555,10 +556,22 @@ def set_ip_adapter_scale(self, scale):
         ```
         """
         transformer = self.transformer
+        if not isinstance(scale, list):
+            scale = [scale]
+
+        scale_configs = scale

         for attn_name, attn_processor in transformer.attn_processors.items():
             if isinstance(attn_processor, (FluxIPAdapterAttnProcessor2_0)):
-                attn_processor.scale = scale
+                if len(scale_configs) != len(attn_processor.scale):
+                    raise ValueError(
+                        f"Cannot assign {len(scale_configs)} scale_configs to "
+                        f"{len(attn_processor.scale)} IP-Adapter."
+                    )
+                elif len(scale_configs) == 1:
+                    scale_configs = scale_configs * len(attn_processor.scale)
+                for i, scale_config in enumerate(scale_configs):
+                    attn_processor.scale[i] = scale_config

     def unload_ip_adapter(self):
         """

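Note: with this change, `weight_name` moves ahead of `subfolder` in the `load_ip_adapter` signature and `subfolder` becomes optional, while `set_ip_adapter_scale` now accepts either a single float or a per-adapter list. A minimal usage sketch follows; the repository id, weight file name, image encoder path, and scale values are illustrative placeholders and assume a diffusers build that includes this commit.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# `subfolder` now defaults to "" and may be omitted; `weight_name` names the file to load.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",            # placeholder repo id
    weight_name="ip_adapter.safetensors",  # placeholder weight file
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",  # placeholder
)

# A scalar applies the same weight to every loaded IP-Adapter; with several adapters
# loaded, a list sets one weight per adapter, e.g. pipe.set_ip_adapter_scale([0.5, 0.8]).
pipe.set_ip_adapter_scale(0.6)
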
src/diffusers/loaders/transformer_flux.py

Lines changed: 38 additions & 28 deletions
@@ -84,7 +84,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us

         return image_projection

-    def _convert_ip_adapter_attn_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
         from ..models.attention_processor import (
             FluxIPAdapterAttnProcessor2_0,
         )
@@ -110,35 +110,47 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dict, low_cpu_mem_usage=Fa

         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
-        key_id = 1
+        key_id = 0
         init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
         for name in self.attn_processors.keys():
             if name.startswith("single_transformer_blocks"):
-                continue
-
-            cross_attention_dim = self.config.joint_attention_dim
-            hidden_size = self.config.inner_dim
-            attn_processor_class = FluxIPAdapterAttnProcessor2_0
-
-            with init_context():
-                attn_procs[name] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                )
-
-            value_dict = {}
-            value_dict.update({"to_k_ip.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
-            value_dict.update({"to_v_ip.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
-
-            if not low_cpu_mem_usage:
-                attn_procs[name].load_state_dict(value_dict)
+                attn_processor_class = self.attn_processors[name].__class__
+                attn_procs[name] = attn_processor_class()
             else:
-                device = next(iter(value_dict.values())).device
-                dtype = next(iter(value_dict.values())).dtype
-                load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
-
-            key_id += 1
+                cross_attention_dim = self.config.joint_attention_dim
+                hidden_size = self.inner_dim
+                attn_processor_class = FluxIPAdapterAttnProcessor2_0
+                num_image_text_embeds = []
+                for state_dict in state_dicts:
+                    if "proj.weight" in state_dict["image_proj"]:
+                        # IP-Adapter
+                        num_image_text_embeds += [4]
+
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                        dtype=self.dtype,
+                        device=self.device,
+                    )
+
+                value_dict = {}
+                for i, state_dict in enumerate(state_dicts):
+                    value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+                    value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+                    value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
+                    value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
+
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = self.device
+                    dtype = self.dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
+                key_id += 1

         return attn_procs

@@ -160,5 +172,3 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):

         self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
         self.config.encoder_hid_dim_type = "ip_image_proj"
-
-        self.to(dtype=self.dtype, device=self.device)

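Note: the converter now takes a list of state dicts (one per IP-Adapter) and builds a single `FluxIPAdapterAttnProcessor2_0` per double-stream block, whose `to_k_ip`/`to_v_ip` ModuleLists hold one Linear per adapter. The runnable sketch below only illustrates the resulting key layout; the dummy state dicts, shapes, and two-adapter setup are hypothetical, not taken from a real checkpoint.

import torch

def fake_adapter_state_dict(num_blocks=1, hidden=16, cross=8):
    # Stand-in for a loaded IP-Adapter checkpoint: {"image_proj": ..., "ip_adapter": ...}
    sd = {"image_proj": {"proj.weight": torch.zeros(1)}, "ip_adapter": {}}
    for block in range(num_blocks):
        for proj in ("to_k_ip", "to_v_ip"):
            sd["ip_adapter"][f"{block}.{proj}.weight"] = torch.zeros(hidden, cross)
            sd["ip_adapter"][f"{block}.{proj}.bias"] = torch.zeros(hidden)
    return sd

state_dicts = [fake_adapter_state_dict(), fake_adapter_state_dict()]  # two adapters (hypothetical)
key_id = 0  # index of the double-stream block being populated

value_dict = {}
for i, state_dict in enumerate(state_dicts):
    # Keys follow the processor's ModuleList parameter names: to_k_ip.{i}.weight, to_v_ip.{i}.bias, ...
    for proj in ("to_k_ip", "to_v_ip"):
        value_dict[f"{proj}.{i}.weight"] = state_dict["ip_adapter"][f"{key_id}.{proj}.weight"]
        value_dict[f"{proj}.{i}.bias"] = state_dict["ip_adapter"][f"{key_id}.{proj}.bias"]

print(sorted(value_dict))  # to_k_ip.0.bias ... to_v_ip.1.weight, ready for load_state_dict()
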
src/diffusers/models/attention_processor.py

Lines changed: 133 additions & 33 deletions
@@ -482,7 +482,7 @@ def forward(
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty

         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"ip_adapter_masks"}
+        quiet_attn_parameters = {"ip_adapter_masks", "image_projection"}
         unused_kwargs = [
             k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
         ]
@@ -1893,31 +1893,43 @@ def __call__(
         return hidden_states


-class FluxIPAdapterAttnProcessor2_0:
+class FluxIPAdapterAttnProcessor2_0(torch.nn.Module):
     """Flux Attention processor for IP-Adapter."""

-    def __init__(self, hidden_size: int, cross_attention_dim: int, scale: float = 1.0):
+    def __init__(
+        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+    ):
         super().__init__()

-        r"""
-        Args:
-            hidden_size (`int`):
-                The hidden size of the attention layer.
-            cross_attention_dim (`int`):
-                The number of channels in the `encoder_hidden_states`.
-            scale (`float`, defaults to 1.0):
-                the weight scale of image prompt.
-        """
-
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
-                "FluxIPAdapterAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
             )

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
         self.scale = scale

-        self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size)
-        self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size)
+        self.to_k_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )

     def __call__(
         self,
@@ -1926,24 +1938,27 @@ def __call__(
         encoder_hidden_states: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        image_projection: Optional[torch.Tensor] = None,
+        image_projection: Optional[List[torch.Tensor]] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
+        if image_projection is None:
+            raise ValueError("image_projection is None")
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
-        query = attn.to_q(hidden_states)
+        hidden_states_query_proj = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         if attn.norm_q is not None:
-            query = attn.norm_q(query)
+            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
         if attn.norm_k is not None:
             key = attn.norm_k(key)

@@ -1970,7 +1985,7 @@ def __call__(
                 encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

             # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

@@ -1997,19 +2012,104 @@ def __call__(
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

             # IP-adapter
-            ip_key = self.to_k_ip(image_projection)
-            ip_value = self.to_v_ip(image_projection)
+            ip_hidden_states = image_projection
+
+            if ip_adapter_masks is not None:
+                if not isinstance(ip_adapter_masks, List):
+                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                    raise ValueError(
+                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                        f"({len(ip_hidden_states)})"
+                    )
+                else:
+                    for index, (mask, scale, ip_state) in enumerate(
+                        zip(ip_adapter_masks, self.scale, ip_hidden_states)
+                    ):
+                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                            raise ValueError(
+                                "Each element of the ip_adapter_masks array should be a tensor with shape "
+                                "[1, num_images_for_ip_adapter, height, width]."
+                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                            )
+                        if mask.shape[1] != ip_state.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                            )
+                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of scales ({len(scale)}) at index {index}"
+                            )
+            else:
+                ip_adapter_masks = [None] * len(self.scale)
+
+            ip_query = hidden_states_query_proj
+            # for ip-adapter
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                skip = False
+                if isinstance(scale, list):
+                    if all(s == 0 for s in scale):
+                        skip = True
+                elif scale == 0:
+                    skip = True
+                if not skip:
+                    if mask is not None:
+                        if not isinstance(scale, list):
+                            scale = [scale] * mask.shape[1]
+
+                        current_num_images = mask.shape[1]
+                        for i in range(current_num_images):
+                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                            # TODO: add support for attn.scale when we move to Torch 2.1
+                            _current_ip_hidden_states = F.scaled_dot_product_attention(
+                                ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                            )
+
+                            _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                                batch_size, -1, attn.heads * head_dim
+                            )
+                            _current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype)
+
+                            mask_downsample = IPAdapterMaskProcessor.downsample(
+                                mask[:, i, :, :],
+                                batch_size,
+                                _current_ip_hidden_states.shape[1],
+                                _current_ip_hidden_states.shape[2],
+                            )
+
+                            mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device)
+                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                    else:
+                        ip_key = to_k_ip(current_ip_hidden_states)
+                        ip_value = to_v_ip(current_ip_hidden_states)

-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-            ip_hidden_states = ip_hidden_states.to(query.dtype)
-            hidden_states = hidden_states + self.scale * ip_hidden_states
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        current_ip_hidden_states = F.scaled_dot_product_attention(
+                            ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
+
+                        current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
+                        current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+
+                        hidden_states = hidden_states + scale * current_ip_hidden_states

             return hidden_states, encoder_hidden_states
         else:

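Note: for the unmasked path, the rewritten processor loops over the loaded adapters, projects each adapter's image embeddings through its own `to_k_ip`/`to_v_ip` Linear, attends against the sample-stream query, and adds the result to `hidden_states` weighted by that adapter's scale. The self-contained sketch below mirrors that accumulation with dummy shapes; every name and size is a placeholder, not taken from a real Flux checkpoint.

import torch
import torch.nn.functional as F

batch_size, heads, head_dim = 2, 4, 8
seq_len, num_ip_tokens, cross_dim = 16, 4, 32
hidden_size = heads * head_dim

hidden_states = torch.zeros(batch_size, seq_len, hidden_size)                     # sample stream
ip_query = torch.randn(batch_size, heads, seq_len, head_dim)                      # sample-stream query
scales = [1.0, 0.6]                                                               # one scale per adapter
to_k_ip = [torch.nn.Linear(cross_dim, hidden_size) for _ in scales]               # per-adapter K projections
to_v_ip = [torch.nn.Linear(cross_dim, hidden_size) for _ in scales]               # per-adapter V projections
ip_states = [torch.randn(batch_size, num_ip_tokens, cross_dim) for _ in scales]   # projected image embeds

for ip_state, scale, k_proj, v_proj in zip(ip_states, scales, to_k_ip, to_v_ip):
    ip_key = k_proj(ip_state).view(batch_size, -1, heads, head_dim).transpose(1, 2)
    ip_value = v_proj(ip_state).view(batch_size, -1, heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(ip_query, ip_key, ip_value)
    out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    hidden_states = hidden_states + scale * out  # each adapter contributes additively

print(hidden_states.shape)  # torch.Size([2, 16, 32])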