Skip to content

Commit 0dce2cb

Browse files
committed
check tensors on accelerator
1 parent 7dab9e3 commit 0dce2cb

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1418,7 +1418,7 @@ def _flash_varlen_attention_3(
14181418

14191419
@_AttentionBackendRegistry.register(
14201420
AttentionBackendName.AITER,
1421-
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
1421+
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
14221422
)
14231423
def _aiter_flash_attention(
14241424
query: torch.Tensor,

0 commit comments

Comments
 (0)