Commit a456fb1

Author: J石页
Commit message: NPU Adaptation for Sanna

1 parent ab2d71b

File tree: 1 file changed, +1 −1 lines

src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 1 deletion
@@ -3158,7 +3158,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         # scaled_dot_product_attention expects attention_mask shape to be
         # (batch, heads, source_length, target_length)
-        attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])
+        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
         attention_mask = attention_mask.repeat_interleave(hidden_states.shape[1], dim=2)
         if attention_mask.dtype == torch.bool:
             attention_mask = torch.logical_not(attention_mask.bool())
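To see what the one-line change does, note that `prepare_attention_mask` has (in this processor style) already repeated the mask over heads, so it arrives flattened with a `batch * heads` leading dimension. The old `view` kept a singleton head axis and folded the per-head copies into the query axis; the new `view` restores an explicit head axis, matching the `(batch, heads, source_length, target_length)` layout named in the comment, which fused attention kernels on NPU backends presumably require rather than broadcasting a singleton head dimension. Below is a minimal sketch of the shape difference; the concrete sizes (`batch_size`, `heads`, `query_len`, `key_len`) are illustrative, and this is not the diffusers code itself:

```python
import torch

batch_size, heads, query_len, key_len = 2, 4, 3, 6

# Assumed input: a mask already repeated over heads and flattened to
# (batch * heads, 1, key_len), as prepare_attention_mask returns here.
attention_mask = torch.ones(batch_size * heads, 1, key_len, dtype=torch.bool)

# Old view: singleton head dim. The per-head copies are folded into dim 2,
# giving (batch, 1, heads, key_len); repeat_interleave on dim=2 then yields
# (batch, 1, heads * query_len, key_len), mixing heads and query positions
# in a single axis.
old = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])
old = old.repeat_interleave(query_len, dim=2)
print(old.shape)  # torch.Size([2, 1, 12, 6])

# New view: explicit head dim, (batch, heads, 1, key_len); the same
# repeat_interleave now yields (batch, heads, query_len, key_len), the
# layout scaled_dot_product_attention documents for attention masks.
new = attention_mask.view(batch_size, heads, -1, attention_mask.shape[-1])
new = new.repeat_interleave(query_len, dim=2)
print(new.shape)  # torch.Size([2, 4, 3, 6])
```

With broadcasting-capable kernels both layouts can produce the same result, which is why the old form worked on CUDA; the explicit-head layout is simply the unambiguous one.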
