
Commit 3add6de

Author: J石页
Commit message: NPU Adaptation for Sana
1 parent 7364276 commit 3add6de

File tree: 1 file changed (+1, -1 lines)


src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 1 deletion
@@ -3159,7 +3159,7 @@ def __call__(
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-            attention_mask = attention_mask.repeat_interleave(hidden_states.shape[1], dim=2)
+            attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
             if attention_mask.dtype == torch.bool:
                 attention_mask = torch.logical_not(attention_mask.bool())
         else:
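
A minimal sanity-check sketch (not part of the commit, with hypothetical tensor sizes): after the .view(...) call above, the mask's third dimension has size 1, and under that assumption repeat(1, 1, n, 1) produces exactly the same tensor as repeat_interleave(n, dim=2). The swap therefore only changes which op is executed, presumably because repeat_interleave is poorly supported or slow on the targeted NPU backend, not the numerical result.

# Illustrative only: with a size-1 dim to expand, repeat() and
# repeat_interleave() are interchangeable. Sizes below are hypothetical.
import torch

batch_size, heads, seq_len, mask_len = 2, 4, 16, 8

# Shape after the .view(...) call in the diff: (batch, heads, 1, mask_len)
attention_mask = torch.ones(batch_size, heads, 1, mask_len, dtype=torch.bool)

old = attention_mask.repeat_interleave(seq_len, dim=2)  # removed line
new = attention_mask.repeat(1, 1, seq_len, 1)           # added line

assert old.shape == new.shape == (batch_size, heads, seq_len, mask_len)
assert torch.equal(old, new)  # identical output when the expanded dim is 1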
