
Commit 111603e

test
1 parent 1dd6762 commit 111603e

File tree

1 file changed (+16 -88)


src/diffusers/models/attention_processor.py

Lines changed: 16 additions & 88 deletions
@@ -2532,6 +2532,7 @@ def __call__(
     ) -> torch.FloatTensor:
         if image_projection is None:
             raise ValueError("image_projection is None")
+        print(image_projection, image_projection.shape)
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
@@ -2603,103 +2604,30 @@ def __call__(
         # IP-adapter
         ip_hidden_states = image_projection

-        if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, List):
-                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
-                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
-            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
-                raise ValueError(
-                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
-                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
-                    f"({len(ip_hidden_states)})"
-                )
-            else:
-                for index, (mask, scale, ip_state) in enumerate(
-                    zip(ip_adapter_masks, self.scale, ip_hidden_states)
-                ):
-                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
-                        raise ValueError(
-                            "Each element of the ip_adapter_masks array should be a tensor with shape "
-                            "[1, num_images_for_ip_adapter, height, width]."
-                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                        )
-                    if mask.shape[1] != ip_state.shape[1]:
-                        raise ValueError(
-                            f"Number of masks ({mask.shape[1]}) does not match "
-                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
-                        )
-                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
-                        raise ValueError(
-                            f"Number of masks ({mask.shape[1]}) does not match "
-                            f"number of scales ({len(scale)}) at index {index}"
-                        )
-        else:
-            ip_adapter_masks = [None] * len(self.scale)
-
         ip_query = hidden_states_query_proj
         ip_attn_output = None
         # for ip-adapter
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            skip = False
-            if isinstance(scale, list):
-                if all(s == 0 for s in scale):
-                    skip = True
-            elif scale == 0:
-                skip = True
-            if not skip:
-                if mask is not None:
-                    if not isinstance(scale, list):
-                        scale = [scale] * mask.shape[1]
-
-                    current_num_images = mask.shape[1]
-                    for i in range(current_num_images):
-                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
-                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+            ip_key = to_k_ip(current_ip_hidden_states)
+            ip_value = to_v_ip(current_ip_hidden_states)

-                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                        # TODO: add support for attn.scale when we move to Torch 2.1
-                        _current_ip_hidden_states = F.scaled_dot_product_attention(
-                            ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                        )
-
-                        _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
-                            batch_size, -1, attn.heads * head_dim
-                        )
-                        _current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype)
-
-                        mask_downsample = IPAdapterMaskProcessor.downsample(
-                            mask[:, i, :, :],
-                            batch_size,
-                            _current_ip_hidden_states.shape[1],
-                            _current_ip_hidden_states.shape[2],
-                        )
-
-                        mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device)
-                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
-                else:
-                    ip_key = to_k_ip(current_ip_hidden_states)
-                    ip_value = to_v_ip(current_ip_hidden_states)
-
-                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                    # TODO: add support for attn.scale when we move to Torch 2.1
-                    current_ip_hidden_states = F.scaled_dot_product_attention(
-                        ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                    )
-
-                    current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                        batch_size, -1, attn.heads * head_dim
-                    )
-                    current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            ip_attn_output = F.scaled_dot_product_attention(
+                ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+            )

-                    ip_attn_output = scale * current_ip_hidden_states
+            ip_attn_output = ip_attn_output.transpose(1, 2).reshape(
+                batch_size, -1, attn.heads * head_dim
+            )
+            ip_attn_output = scale * ip_attn_output
+            print(ip_attn_output)
+            ip_attn_output = ip_attn_output.to(ip_query.dtype)

         return hidden_states, encoder_hidden_states, ip_attn_output
     else:
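
For context, here is a minimal standalone sketch of the simplified IP-adapter attention path this commit leaves in place: one key/value projection per adapter, plain scaled-dot-product attention against the image tokens, and a scalar scale, with the per-region mask handling and zero-scale skip removed. This is not the diffusers API; the tensor sizes, the Linear projections, and the single-adapter setup below are stand-ins chosen for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, num_tokens = 2, 64, 4  # latent tokens vs. image-prompt tokens
heads, head_dim = 8, 40
dim = heads * head_dim

# Query from the hidden states, already split into heads: (batch, heads, seq_len, head_dim)
ip_query = torch.randn(batch_size, heads, seq_len, head_dim)

# One entry per IP adapter (a single adapter here)
ip_hidden_states = [torch.randn(batch_size, num_tokens, dim)]
to_k_ip = [nn.Linear(dim, dim, bias=False)]
to_v_ip = [nn.Linear(dim, dim, bias=False)]
scale = [1.0]

ip_attn_output = None
for current, s, to_k, to_v in zip(ip_hidden_states, scale, to_k_ip, to_v_ip):
    # Project image tokens to per-head keys/values: (batch, heads, num_tokens, head_dim)
    ip_key = to_k(current).view(batch_size, -1, heads, head_dim).transpose(1, 2)
    ip_value = to_v(current).view(batch_size, -1, heads, head_dim).transpose(1, 2)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    out = F.scaled_dot_product_attention(
        ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
    )

    # Merge heads back and apply the adapter scale
    ip_attn_output = s * out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)

print(ip_attn_output.shape)  # torch.Size([2, 64, 320])

Note that, as in the diff, the loop overwrites ip_attn_output on each iteration rather than accumulating into hidden_states, so with multiple adapters only the last adapter's output would be returned; that matches the committed code, which returns ip_attn_output separately alongside hidden_states.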
