@@ -125,6 +125,13 @@ def keep(conf):
     return not (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)
 
 
+def prune_invalid_configs(configs, named_args, **kwargs):
+    N_CTX = kwargs["N_CTX"]
+
+    # Filter out configs where BLOCK_M > N_CTX
+    return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]
+
+
 @triton.jit
 def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
     if isinstance(desc_or_ptr, tl.tensor_descriptor):
@@ -133,7 +140,8 @@ def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
     return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)
 
 
-@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"])
+@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
+                 prune_configs_by={'early_config_prune': prune_invalid_configs})
 @triton.jit
 def _attn_fwd(sm_scale, M,  #
               Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX,  #
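
For context (not part of the commit): `prune_configs_by={'early_config_prune': ...}` lets `triton.autotune` drop candidate configs before any benchmarking, based on the kernel's runtime arguments. The sketch below is a hypothetical, minimal use of that hook outside the attention tutorial; the kernel, helper, and config values are made up for illustration, and it assumes `N_CTX` is passed by keyword so the hook can read it from `**kwargs`, as the tutorial kernel does above.

# Minimal sketch of the early_config_prune hook (hypothetical names/values).
import torch
import triton
import triton.language as tl

example_configs = [triton.Config({"BLOCK_M": bm}, num_warps=4) for bm in (64, 128, 256)]

def prune_by_seq_len(configs, named_args, **kwargs):
    # Runs once before benchmarking: drop tile sizes larger than the problem.
    n_ctx = kwargs["N_CTX"]
    return [c for c in configs if c.kwargs.get("BLOCK_M", 0) <= n_ctx]

@triton.autotune(configs=example_configs, key=["N_CTX"],
                 prune_configs_by={"early_config_prune": prune_by_seq_len})
@triton.jit
def copy_kernel(x_ptr, y_ptr, N_CTX, BLOCK_M: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = offs < N_CTX
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

def copy(x):
    y = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(x.numel(), META["BLOCK_M"]),)
    # N_CTX is passed by keyword so the prune hook can read it from **kwargs.
    copy_kernel[grid](x, y, N_CTX=x.numel())
    return y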