
Commit 41a6e86

dxqb, github-actions[bot], and sayakpaul authored
Check for attention mask in backends that don't support it (#12892)
* check attention mask
* Apply style fixes
* bugfix

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 9b5a244 commit 41a6e86

File tree: 1 file changed (+47, -0 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 47 additions & 0 deletions
@@ -1420,13 +1420,17 @@ def _flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     lse = None
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
     if _parallel_config is None:
         out = flash_attn_func(
             q=query,
@@ -1469,13 +1473,17 @@ def _flash_attention_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     lse = None
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
     out = func(
         q=query,
@@ -1612,11 +1620,15 @@ def _flash_attention_3(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 3.")
+
     out, lse = _wrapped_flash_attn_3(
         q=query,
         k=key,
@@ -1636,6 +1648,7 @@ def _flash_attention_3_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     is_causal: bool = False,
     window_size: Tuple[int, int] = (-1, -1),
@@ -1646,6 +1659,8 @@ def _flash_attention_3_hub(
 ) -> torch.Tensor:
     if _parallel_config:
         raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 3.")
 
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
     out = func(
@@ -1785,12 +1800,16 @@ def _aiter_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for aiter attention")
+
     if not return_lse and torch.is_grad_enabled():
         # aiter requires return_lse=True by assertion when gradients are enabled.
         out, lse, *_ = aiter_flash_attn_func(
@@ -2028,13 +2047,17 @@ def _native_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for aiter attention")
+
     lse = None
     if _parallel_config is None and not return_lse:
         query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2113,11 +2136,14 @@ def _native_npu_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for NPU attention")
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
@@ -2148,10 +2174,13 @@ def _native_xla_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for XLA attention")
     if return_lse:
         raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2175,11 +2204,14 @@ def _sage_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     lse = None
     if _parallel_config is None:
         out = sageattn(
@@ -2223,11 +2255,14 @@ def _sage_attention_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     lse = None
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
     if _parallel_config is None:
@@ -2309,11 +2344,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp8_cuda(
         q=query,
         k=key,
@@ -2333,11 +2371,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp8_cuda_sm90(
         q=query,
         k=key,
@@ -2357,11 +2398,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp16_cuda(
         q=query,
         k=key,
@@ -2381,11 +2425,14 @@ def _sage_qk_int8_pv_fp16_triton_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp16_triton(
         q=query,
         k=key,
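
Every hunk above applies the same early guard, so the whole change reduces to one pattern: accept `attn_mask` in the signature for API uniformity, then reject it before dispatching to a backend that cannot consume a mask. The snippet below is a minimal, hypothetical sketch of that pattern (the helper `_check_no_attn_mask` and its call site are illustrative only and not part of the diffusers codebase; the commit inlines the check in each backend wrapper):

```python
from typing import Optional

import torch


def _check_no_attn_mask(attn_mask: Optional[torch.Tensor], backend_name: str) -> None:
    # Mirrors the guard added in this commit: backends that cannot consume an
    # explicit attention mask (flash-attn 2/3, aiter, NPU, XLA, sage attention)
    # fail fast instead of silently ignoring the mask.
    if attn_mask is not None:
        raise ValueError(f"`attn_mask` is not supported for {backend_name}.")


# Passing a mask to a backend guarded this way raises immediately.
mask = torch.ones(1, 1, 8, 8, dtype=torch.bool)
try:
    _check_no_attn_mask(mask, "flash-attn 2")
except ValueError as err:
    print(err)  # `attn_mask` is not supported for flash-attn 2.
```

Raising up front is preferable to dropping the argument, since the caller would otherwise get attention outputs computed without the mask and no indication that it was ever ignored.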
