@@ -66,7 +66,7 @@ def maybe_to_contiguous(x):
6666 return wrapper
6767
6868
69- FWD_BASE_AUTOTUNE_KAYS = ["IS_CAUSAL" , "IS_LOCAL" , "TILE_K" ]
69+ FWD_BASE_AUTOTUNE_KEYS = ["seqlen_q" , "seqlen_k" , "IS_CAUSAL" , "IS_LOCAL" , "TILE_K" ]
7070
7171
7272def get_fwd_base_autotune_configs (autotune : bool ):
@@ -121,14 +121,14 @@ def get_fwd_base_autotune_configs(autotune: bool):
121121
122122 configs = []
123123 BLOCK_M_OPTIONS = [256 , 128 , 64 , 32 ]
124- BLCOK_N_OPTIONS = [256 , 128 , 64 , 32 ]
125- NUM_WARPS_OPTIONS = [2 , 4 , 8 ]
126- NUM_STAGES_OPTION = [1 , 2 ]
124+ BLOCK_N_OPTIONS = [256 , 128 , 64 , 32 ]
125+ NUM_WARPS_OPTIONS = [4 , 8 ]
126+ NUM_STAGES_OPTIONS = [1 , 2 ]
127127
128128 for bm in BLOCK_M_OPTIONS :
129- for bn in BLCOK_N_OPTIONS :
129+ for bn in BLOCK_N_OPTIONS :
130130 for nw in NUM_WARPS_OPTIONS :
131- for ns in NUM_STAGES_OPTION :
131+ for ns in NUM_STAGES_OPTIONS :
132132 configs .append (
133133 triton .Config (
134134 {
@@ -214,7 +214,7 @@ def assert_fwd_base_inputs(
214214 "head_dim must be a multiple of 16 for efficient memory access"
215215 )
216216 assert head_dim <= 256 , (
217- "head_dim must be less than or equal to 256 for efficient memory access` "
217+ "head_dim must be less than or equal to 256 for efficient memory access"
218218 )
219219 if cu_seqlens_q is not None and cu_seqlens_k is not None :
220220 assert cu_seqlens_q .is_cuda and cu_seqlens_k .is_cuda , (
0 commit comments