
Commit ab2d71b

Author: J石页 (committed)

NPU Adaptation for Sanna
1 parent d61d570 commit ab2d71b

File tree

1 file changed: 5 additions, 3 deletions


src/diffusers/models/attention_processor.py

Lines changed: 5 additions & 3 deletions
@@ -3158,10 +3158,12 @@ def __call__(
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
-            attn_mask = attention_mask[0]
-            seq_len = hidden_states.shape[1]
-            attention_mask = attn_mask.repeat_interleave(seq_len * batch_size, dim=0)
             attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])
+            attention_mask = attention_mask.repeat_interleave(hidden_states.shape[1], dim=2)
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.logical_not(attention_mask.bool())
+            else:
+                attention_mask = attention_mask.bool()

             if attention_mask.dtype != torch.uint8:
                 if attention_mask.dtype == torch.bool:
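
For readers who want to see what the new mask handling does, below is a minimal, self-contained sketch of the added lines. It is not part of the commit: the example shapes (batch_size=2, seq_len=4, src_len=6, one head) are chosen for illustration, and the idea that the NPU fused-attention kernel expects an inverted boolean mask (True = position masked out) is inferred from the logical_not in the diff, not confirmed by the source.

import torch

# Example shapes (assumed for illustration only).
batch_size, seq_len, src_len = 2, 4, 6
hidden_states = torch.randn(batch_size, seq_len, 8)

# After attn.prepare_attention_mask(...) the mask is (batch * heads, 1, src_len);
# a single head is assumed here. All-True = "attend everywhere" in SDPA terms.
attention_mask = torch.ones(batch_size, 1, src_len, dtype=torch.bool)

# Reshape to the (batch, heads, source_length, target_length) layout named in
# the in-code comment ...
attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])
# ... and repeat the row once per query position. This replaces the removed
# lines, which built the mask from only the first row, attention_mask[0].
attention_mask = attention_mask.repeat_interleave(hidden_states.shape[1], dim=2)

# Convert to the boolean convention the NPU kernel appears to expect:
# bool masks are inverted; additive float masks (0 / -inf) become True
# exactly where the entries are nonzero, i.e. where attention is blocked.
if attention_mask.dtype == torch.bool:
    attention_mask = torch.logical_not(attention_mask.bool())
else:
    attention_mask = attention_mask.bool()

print(attention_mask.shape)  # torch.Size([2, 1, 4, 6])

Run with the shapes above, this yields a (2, 1, 4, 6) boolean tensor, one row of source-length flags per query position, matching the (batch, heads, source_length, target_length) layout the diff's comment describes.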

0 commit comments
