
Commit e6c4956

ThomasRaoux authored and whitneywhtsang committed
[TUTORIAL] Remove invalid config in attention (#6889)
This solves the sporadic failures caused by autotune configs whose BLOCK_M exceeds N_CTX.
1 parent 4c27cdf commit e6c4956

File tree

1 file changed: +9 −1 lines changed


python/tutorials/06-fused-attention.py

Lines changed: 9 additions & 1 deletion
@@ -126,6 +126,13 @@ def keep(conf):
             and conf.num_warps == 8)
 
 
+def prune_invalid_configs(configs, named_args, **kwargs):
+    N_CTX = kwargs["N_CTX"]
+
+    # Filter out configs where BLOCK_M > N_CTX
+    return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]
+
+
 @triton.jit
 def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
     if isinstance(desc_or_ptr, tl.tensor_descriptor):
@@ -134,7 +141,8 @@ def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
         return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)
 
 
-@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"])
+@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
+                 prune_configs_by={'early_config_prune': prune_invalid_configs})
 @triton.jit
 def _attn_fwd(sm_scale, M, #
               Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, #
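
For context, here is a minimal, self-contained sketch of how an early_config_prune hook plugs into triton.autotune. It is not part of the commit: add_kernel, its arguments, and the two candidate configs are hypothetical stand-ins for the tutorial's attention kernel, while the hook signature and the prune_configs_by keyword follow the change above.

import torch
import triton
import triton.language as tl

# Two hypothetical autotune candidates; only the first is valid for short inputs.
configs = [
    triton.Config({"BLOCK_M": 64}, num_warps=4),
    triton.Config({"BLOCK_M": 128}, num_warps=4),
]


def prune_invalid_configs(configs, named_args, **kwargs):
    N_CTX = kwargs["N_CTX"]
    # Same rule as the commit: drop configs whose BLOCK_M exceeds N_CTX.
    return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]


@triton.autotune(configs=configs, key=["N_CTX"],
                 prune_configs_by={"early_config_prune": prune_invalid_configs})
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, N_CTX, BLOCK_M: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = offs < N_CTX
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


x = torch.randn(96, device="cuda")
y = torch.randn(96, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(96, meta["BLOCK_M"]),)
# N_CTX is passed by keyword so the prune hook sees it in **kwargs,
# mirroring how the tutorial passes N_CTX to _attn_fwd.
add_kernel[grid](x, y, out, N_CTX=96)  # the BLOCK_M=128 config is pruned

The hook runs before autotuning benchmarks the candidates, so an invalid config is never compiled or launched, rather than merely losing the timing comparison.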
