@@ -126,6 +126,13 @@ def keep(conf):
126
126
and conf .num_warps == 8 )
127
127
128
128
129
def prune_invalid_configs(configs, named_args, **kwargs):
    """Drop autotune configs whose ``BLOCK_M`` is larger than ``N_CTX``.

    Written for use as the ``early_config_prune`` hook of ``triton.autotune``.

    Args:
        configs: Candidate configs; each exposes a ``kwargs`` mapping that may
            contain ``"BLOCK_M"``.
        named_args: Mapping of argument name to value supplied by the
            autotuner (per the Triton autotuner, the arguments that were bound
            by position). Only consulted when ``N_CTX`` is not in ``kwargs``.
        **kwargs: Keyword arguments of the launch; ``N_CTX`` is read from here
            first.

    Returns:
        A new list with the configs whose ``BLOCK_M`` (treated as 0 when the
        config does not define it) is at most ``N_CTX``, in input order.

    Raises:
        KeyError: If ``N_CTX`` is found in neither ``kwargs`` nor
            ``named_args``.
    """
    # A keyword N_CTX wins; a positionally passed N_CTX is only visible through
    # named_args, so fall back to it instead of failing with KeyError.
    if "N_CTX" in kwargs:
        n_ctx = kwargs["N_CTX"]
    else:
        n_ctx = (named_args or {})["N_CTX"]

    # A block of BLOCK_M rows cannot be larger than the N_CTX extent it tiles.
    return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= n_ctx]
134
+
135
+
129
136
@triton .jit
130
137
def _maybe_make_tensor_desc (desc_or_ptr , shape , strides , block_shape ):
131
138
if isinstance (desc_or_ptr , tl .tensor_descriptor ):
@@ -134,7 +141,8 @@ def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
134
141
return tl .make_tensor_descriptor (desc_or_ptr , shape , strides , block_shape )
135
142
136
143
137
- @triton .autotune (configs = list (filter (keep , configs )), key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" , "warp_specialize" ])
144
+ @triton .autotune (configs = list (filter (keep , configs )), key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" , "warp_specialize" ],
145
+ prune_configs_by = {'early_config_prune' : prune_invalid_configs })
138
146
@triton .jit
139
147
def _attn_fwd (sm_scale , M , #
140
148
Z , H , desc_q , desc_k , desc_v , desc_o , N_CTX , #
0 commit comments