diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 6a05aac215c6..e123f4c19398 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -955,12 +955,13 @@ def _native_npu_attention(
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
 ) -> torch.Tensor:
-    return npu_fusion_attention(
+    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+    out = npu_fusion_attention(
         query,
         key,
         value,
-        query.size(2),  # num_heads
-        input_layout="BSND",
+        query.size(1),  # num_heads
+        input_layout="BNSD",
         pse=None,
         scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
         pre_tockens=65536,
@@ -969,6 +970,8 @@ def _native_npu_attention(
         sync=False,
         inner_precise=0,
     )[0]
+    out = out.transpose(1, 2).contiguous()
+    return out
 
 
 # Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
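
The patch switches the fused-attention call from the BSND layout to BNSD by transposing the head and sequence dimensions before the call and transposing back afterwards, which is why the head count moves from query.size(2) to query.size(1). Below is a minimal shape-bookkeeping sketch of that layout handling, runnable with plain torch and not tied to NPU hardware: torch.nn.functional.scaled_dot_product_attention is used only as a stand-in for npu_fusion_attention (which requires torch_npu), and the tensor sizes are illustrative assumptions.

import torch
import torch.nn.functional as F

# Assumed example sizes; the real values come from the model.
batch, seq_len, num_heads, head_dim = 2, 16, 8, 64

# Inputs arrive in BSND layout: (batch, seq, num_heads, head_dim).
query = torch.randn(batch, seq_len, num_heads, head_dim)
key = torch.randn(batch, seq_len, num_heads, head_dim)
value = torch.randn(batch, seq_len, num_heads, head_dim)

# Transpose to BNSD, mirroring what the patch does before the fused call.
q, k, v = (x.transpose(1, 2).contiguous() for x in (query, key, value))
assert q.size(1) == num_heads  # after the transpose, dim 1 is the head count

# Stand-in for npu_fusion_attention: SDPA also expects (batch, heads, seq, dim).
out = F.scaled_dot_product_attention(q, k, v)

# Transpose back so the caller still receives BSND, as in the patched function.
out = out.transpose(1, 2).contiguous()
assert out.shape == (batch, seq_len, num_heads, head_dim)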