Commit cd60a83

Include HAS_MASK/HAS_BIAS/HAS_INDICE in autotune key; tighten mask/bias dtype checks
Add HAS_MASK, HAS_BIAS, and HAS_INDICE to the autotune key so that kernel configs are cached separately for each combination of mask/bias/indice usage. Also enforce that the bias dtype matches the query dtype (fp16/bf16 only) and standardize the mask dtype assertion message.
1 parent 730f40e commit cd60a83
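
Since Triton caches autotuned configs by the values of the arguments listed in `key`, adding HAS_MASK/HAS_BIAS/HAS_INDICE means a launch that uses a mask or bias no longer reuses a config tuned for a mask-free launch (and vice versa). The toy kernel below is a minimal sketch of that mechanism only, not the flash_dmattn kernel; the kernel name, arguments, and configs are made up for illustration.

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2),
        triton.Config({"BLOCK": 128}, num_warps=4),
    ],
    # The best config is benchmarked and cached per distinct (N, HAS_MASK)
    # value pair, so masked and unmasked launches never share a tuned config.
    key=["N", "HAS_MASK"],
)
@triton.jit
def _scale_kernel(x_ptr, mask_ptr, out_ptr, N, HAS_MASK: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    in_bounds = offs < N
    x = tl.load(x_ptr + offs, mask=in_bounds, other=0.0)
    if HAS_MASK:
        # mask_ptr points at an int8 tensor in this sketch
        m = tl.load(mask_ptr + offs, mask=in_bounds, other=0)
        x = tl.where(m != 0, x, 0.0)
    tl.store(out_ptr + offs, x * 2.0, mask=in_bounds)


def scale(x, mask=None):
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
    dummy = mask if mask is not None else x  # unused pointer when HAS_MASK is False
    _scale_kernel[grid](x, dummy, out, n, HAS_MASK=mask is not None)
    return out
```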

flash_dmattn/flash_dmattn_triton.py

Lines changed: 5 additions & 5 deletions
@@ -688,7 +688,7 @@ def init_func(nargs):
             pre_hook=init_to_zero(["DQ", "DBias"]),
         ),
     ],
-    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "BLOCK_HEADDIM"],
+    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "HAS_MASK", "HAS_BIAS", "HAS_INDICE", "BLOCK_HEADDIM"],
 )
 @triton.heuristics(
     {
@@ -903,7 +903,7 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False

     has_mask = mask is not None
     if has_mask:
-        assert mask.dtype == torch.bool, "Only support bool mask"
+        assert mask.dtype == torch.bool, "Only support bool"
         assert mask.is_cuda
         nheads_mask = mask.shape[1]
     else:
@@ -912,7 +912,7 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False

     has_bias = bias is not None
     if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
+        assert bias.dtype == q.dtype, "Only support fp16 and bf16"
         assert bias.is_cuda
         nheads_bias = bias.shape[1]
     else:
@@ -999,15 +999,15 @@ def _flash_attn_backward(

     has_mask = mask is not None
     if has_mask:
-        assert mask.dtype == torch.bool, "Only support bool mask"
+        assert mask.dtype == torch.bool, "Only support bool"
         nheads_mask = mask.shape[1]
     else:
         nheads_mask = 1
         mask = torch.empty(0, device=q.device, dtype=torch.bool)

     has_bias = bias is not None
     if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
+        assert bias.dtype == q.dtype, "Only support fp16 and bf16"
         nheads_bias = bias.shape[1]
     else:
         nheads_bias = 1
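
To make the effect of the tightened checks concrete, here is a small standalone sketch that mirrors the new assertions outside the kernel wrapper; `check_mask_bias` is a hypothetical helper, not part of flash_dmattn's API. The mask must be `torch.bool`, and the bias must now match the query dtype exactly (fp16 or bf16), whereas a float32 bias was previously accepted.

```python
import torch


def check_mask_bias(q, mask=None, bias=None):
    # q is assumed fp16/bf16, consistent with the "Only support fp16 and bf16" message.
    assert q.dtype in (torch.float16, torch.bfloat16), "Only support fp16 and bf16"
    if mask is not None:
        assert mask.dtype == torch.bool, "Only support bool"
    if bias is not None:
        # Before this commit: bias.dtype in [q.dtype, torch.float]; fp32 bias is no longer allowed.
        assert bias.dtype == q.dtype, "Only support fp16 and bf16"


q = torch.randn(1, 8, 128, 64, dtype=torch.float16)
mask = torch.ones(1, 1, 128, 128, dtype=torch.bool)
bias = torch.zeros(1, 1, 128, 128, dtype=torch.float16)
check_mask_bias(q, mask, bias)              # passes

try:
    check_mask_bias(q, mask, bias.float())  # fp32 bias now fails
except AssertionError as err:
    print(err)                              # -> Only support fp16 and bf16
```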
