@@ -126,6 +126,13 @@ def keep(conf):
126126 and conf .num_warps == 8 )
127127
128128
def prune_invalid_configs(configs, named_args, **kwargs):
    """Drop autotune configs whose BLOCK_M is larger than N_CTX.

    Used as the ``early_config_prune`` hook of ``triton.autotune``.

    Args:
        configs: Candidate configs; each exposes a ``kwargs`` mapping of
            meta-parameters.
        named_args: Mapping of kernel argument names to values. Consulted
            for ``N_CTX`` only when it is not present in ``kwargs``
            (NOTE(review): expected to hold positionally bound launch
            arguments -- confirm against the Triton version in use).
        **kwargs: Keyword arguments of the kernel launch.

    Returns:
        A new list with every config whose ``BLOCK_M`` is ``<= N_CTX``, in
        the original order. A config without ``BLOCK_M`` is treated as
        ``BLOCK_M == 0``.

    Raises:
        KeyError: If ``N_CTX`` is found in neither ``kwargs`` nor
            ``named_args``.
    """
    # Launch keywords take precedence; fall back to named_args so a launch
    # that passes N_CTX positionally does not fail with a KeyError here.
    if "N_CTX" in kwargs:
        n_ctx = kwargs["N_CTX"]
    else:
        n_ctx = (named_args or {})["N_CTX"]

    # Filter out configs where BLOCK_M > N_CTX
    return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= n_ctx]
134+
135+
129136@triton .jit
130137def _maybe_make_tensor_desc (desc_or_ptr , shape , strides , block_shape ):
131138 if isinstance (desc_or_ptr , tl .tensor_descriptor ):
@@ -134,7 +141,8 @@ def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
134141 return tl .make_tensor_descriptor (desc_or_ptr , shape , strides , block_shape )
135142
136143
137- @triton .autotune (configs = list (filter (keep , configs )), key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" , "warp_specialize" ])
144+ @triton .autotune (configs = list (filter (keep , configs )), key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" , "warp_specialize" ],
145+ prune_configs_by = {'early_config_prune' : prune_invalid_configs })
138146@triton .jit
139147def _attn_fwd (sm_scale , M , #
140148 Z , H , desc_q , desc_k , desc_v , desc_o , N_CTX , #
0 commit comments