
Commit 233cfbe

test
1 parent 2ee946f commit 233cfbe

3 files changed: +41 -11 lines

src/diffusers/models/attention_processor.py

Lines changed: 21 additions & 5 deletions
@@ -2604,15 +2604,15 @@ def __call__(
             ip_hidden_states = image_projection
 
             ip_query = hidden_states_query_proj
-            ip_attn_output = None
+            ip_attn_outputs = []
             # for ip-adapter
             # TODO: fix for multiple
-            # NOTE: run zeros image embed at the same time?
             for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
                 ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
             ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
+                positive_ip, negative_ip = current_ip_hidden_states
+                ip_key = to_k_ip(positive_ip)
+                ip_value = to_v_ip(positive_ip)
 
                 ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -2624,8 +2624,24 @@ def __call__(
                 ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                 ip_attn_output = scale * ip_attn_output
                 ip_attn_output = ip_attn_output.to(ip_query.dtype)
+                ip_attn_outputs.append(ip_attn_output)
 
-            return hidden_states, encoder_hidden_states, ip_attn_output
+                ip_key = to_k_ip(negative_ip)
+                ip_value = to_v_ip(negative_ip)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                ip_attn_output = F.scaled_dot_product_attention(
+                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+                ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                ip_attn_output = scale * ip_attn_output
+                ip_attn_output = ip_attn_output.to(ip_query.dtype)
+                ip_attn_outputs.append(ip_attn_output)
+
+            return hidden_states, encoder_hidden_states, ip_attn_outputs
         else:
             return hidden_states
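In short, the processor now expects each entry of ip_hidden_states to be a (positive, negative) pair, runs scaled-dot-product attention once for each half, and returns the two results as a list instead of a single tensor. Below is a minimal, self-contained sketch of that pattern; the toy linear projections, the tensor shapes, and the zero tensor standing in for the encoded blank image are illustrative assumptions, not the diffusers API.

import torch
import torch.nn as nn
import torch.nn.functional as F

def dual_ip_attention(ip_query, ip_pair, to_k_ip, to_v_ip, scale, heads, head_dim):
    """Run IP-adapter attention for a (positive, negative) pair of image embeds."""
    batch_size = ip_query.shape[0]
    outputs = []
    for ip_states in ip_pair:  # positive embeds first, then the blank-image ("negative") embeds
        ip_key = to_k_ip(ip_states).view(batch_size, -1, heads, head_dim).transpose(1, 2)
        ip_value = to_v_ip(ip_states).view(batch_size, -1, heads, head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
        out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        outputs.append(scale * out.to(ip_query.dtype))
    return outputs  # [positive_ip_attn, negative_ip_attn]

# Toy usage with random tensors (shapes are made up for illustration).
batch, heads, head_dim, seq_len, ip_tokens = 1, 4, 8, 16, 4
dim = heads * head_dim
to_k_ip, to_v_ip = nn.Linear(dim, dim), nn.Linear(dim, dim)
ip_query = torch.randn(batch, heads, seq_len, head_dim)
positive_ip = torch.randn(batch, ip_tokens, dim)
negative_ip = torch.zeros_like(positive_ip)  # stand-in for the encoded blank image
positive_ip_attn, negative_ip_attn = dual_ip_attention(
    ip_query, (positive_ip, negative_ip), to_k_ip, to_v_ip, scale=1.0, heads=heads, head_dim=head_dim
)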

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 3 additions & 2 deletions
@@ -187,7 +187,7 @@ def forward(
         if len(attention_outputs) == 2:
             attn_output, context_attn_output = attention_outputs
         elif len(attention_outputs) == 3:
-            attn_output, context_attn_output, ip_attn_output = attention_outputs
+            attn_output, context_attn_output, ip_attn_outputs = attention_outputs
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
@@ -201,7 +201,8 @@ def forward(
 
         hidden_states = hidden_states + ff_output
         if len(attention_outputs) == 3:
-            hidden_states = hidden_states + ip_attn_output
+            positive_ip_attn, negative_ip_attn = ip_attn_outputs
+            hidden_states = hidden_states + positive_ip_attn + negative_ip_attn
 
         # Process attention outputs for the `encoder_hidden_states`.
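On the transformer side the change is purely about consuming that pair: when the attention processor returns three outputs, the block unpacks the IP-adapter pair and adds both halves to the residual stream. A shape-only sketch, with made-up tensor sizes:

import torch

# hidden_states and both IP-adapter attention outputs share shape (batch, seq_len, dim)
hidden_states = torch.randn(1, 16, 32)
ip_attn_outputs = (torch.randn(1, 16, 32), torch.randn(1, 16, 32))  # (positive, negative)

positive_ip_attn, negative_ip_attn = ip_attn_outputs
hidden_states = hidden_states + positive_ip_attn + negative_ip_attn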

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 17 additions & 4 deletions
@@ -401,9 +401,11 @@ def encode_image(self, image, device, num_images_per_prompt):
         return image_embeds
 
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, height, width, dtype
     ):
         image_embeds = []
+        negative_embeds = []
+        negative_image = np.zeros((width, height, 3), dtype=np.uint8)
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]
@@ -417,19 +419,27 @@ def prepare_ip_adapter_image_embeds(
                 ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
             ):
                 single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+                negative_image_embeds = self.encode_image(negative_image, device, 1)
 
                 image_embeds.append(single_image_embeds[None, :])
             image_embeds = self.transformer.encoder_hid_proj(image_embeds)
+            negative_embeds.append(negative_image_embeds[None, :])
+            negative_embeds = self.transformer.encoder_hid_proj(negative_embeds)
         else:
             for single_image_embeds in ip_adapter_image_embeds:
                 image_embeds = self.transformer.encoder_hid_proj(single_image_embeds)
                 image_embeds.append(single_image_embeds)
+                negative_image_embeds = self.encode_image(negative_image, device, 1)
+                negative_embeds.append(negative_image_embeds[None, :])
+            negative_embeds = self.transformer.encoder_hid_proj(negative_embeds)
 
         ip_adapter_image_embeds = []
-        for i, single_image_embeds in enumerate(image_embeds):
+        for i, (single_image_embeds, negative_image_embed) in enumerate(zip(image_embeds, negative_embeds)):
             single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_image_embeds = single_image_embeds.to(device=device)
-            ip_adapter_image_embeds.append(single_image_embeds)
+            single_image_embeds = single_image_embeds.to(device=device, dtype=dtype)
+            negative_image_embed = torch.cat([negative_image_embed] * num_images_per_prompt, dim=0)
+            negative_image_embed = negative_image_embed.to(device=device, dtype=dtype)
+            ip_adapter_image_embeds.append((single_image_embeds, negative_image_embed))
 
         return ip_adapter_image_embeds
 
@@ -794,6 +804,9 @@ def __call__(
                 ip_adapter_image_embeds,
                 device,
                 batch_size * num_images_per_prompt,
+                height,
+                width,
+                latents.dtype,
             )
             if self.joint_attention_kwargs is None:
                 self._joint_attention_kwargs = {}
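The pipeline applies the same idea one level up: for every user-supplied IP-adapter image it also encodes a blank (all-zeros) image at the requested height and width, then moves each (positive, negative) pair to the target device and dtype before handing it to the transformer. The sketch below mirrors that flow under simplifying assumptions: prepare_paired_ip_embeds and the encode_image callable are stand-ins invented for illustration, and the projection through the transformer's encoder_hid_proj that the real method performs is left out.

import numpy as np
import torch

def prepare_paired_ip_embeds(images, encode_image, num_images_per_prompt, height, width, device, dtype):
    # `encode_image` is a placeholder for the pipeline's image-encoder call; any callable
    # that maps an image to a (num_tokens, dim) tensor works for this sketch.
    negative_image = np.zeros((width, height, 3), dtype=np.uint8)  # blank image, as in the diff

    pairs = []
    for image in images:
        positive = encode_image(image)[None, :]          # add a leading batch axis
        negative = encode_image(negative_image)[None, :]
        positive = torch.cat([positive] * num_images_per_prompt, dim=0).to(device=device, dtype=dtype)
        negative = torch.cat([negative] * num_images_per_prompt, dim=0).to(device=device, dtype=dtype)
        pairs.append((positive, negative))
    return pairs

# Toy usage: a fake encoder that ignores its input and returns random 4x32 "embeddings".
fake_encoder = lambda img: torch.randn(4, 32)
pairs = prepare_paired_ip_embeds(
    images=["some_image"], encode_image=fake_encoder, num_images_per_prompt=2,
    height=64, width=64, device="cpu", dtype=torch.float32,
)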
