Skip to content

Commit 9bdb282

Browse files
committed
Fix autotune key typo and update configuration options for forward base kernel
1 parent: 2b5e339 · commit: 9bdb282

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

flash_sparse_attn/ops/triton/utils.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def maybe_to_contiguous(x):
6666
return wrapper
6767

6868

69-
FWD_BASE_AUTOTUNE_KAYS = ["IS_CAUSAL", "IS_LOCAL", "TILE_K"]
69+
FWD_BASE_AUTOTUNE_KEYS = ["seqlen_q", "seqlen_k", "IS_CAUSAL", "IS_LOCAL", "TILE_K"]
7070

7171

7272
def get_fwd_base_autotune_configs(autotune: bool):
@@ -121,14 +121,14 @@ def get_fwd_base_autotune_configs(autotune: bool):
121121

122122
configs = []
123123
BLOCK_M_OPTIONS = [256, 128, 64, 32]
124-
BLCOK_N_OPTIONS = [256, 128, 64, 32]
125-
NUM_WARPS_OPTIONS = [2, 4, 8]
126-
NUM_STAGES_OPTION = [1, 2]
124+
BLOCK_N_OPTIONS = [256, 128, 64, 32]
125+
NUM_WARPS_OPTIONS = [4, 8]
126+
NUM_STAGES_OPTIONS = [1, 2]
127127

128128
for bm in BLOCK_M_OPTIONS:
129-
for bn in BLCOK_N_OPTIONS:
129+
for bn in BLOCK_N_OPTIONS:
130130
for nw in NUM_WARPS_OPTIONS:
131-
for ns in NUM_STAGES_OPTION:
131+
for ns in NUM_STAGES_OPTIONS:
132132
configs.append(
133133
triton.Config(
134134
{
@@ -214,7 +214,7 @@ def assert_fwd_base_inputs(
214214
"head_dim must be a multiple of 16 for efficient memory access"
215215
)
216216
assert head_dim <= 256, (
217-
"head_dim must be less than or equal to 256 for efficient memory access`"
217+
"head_dim must be less than or equal to 256 for efficient memory access"
218218
)
219219
if cu_seqlens_q is not None and cu_seqlens_k is not None:
220220
assert cu_seqlens_q.is_cuda and cu_seqlens_k.is_cuda, (

0 commit comments

Comments
 (0)