Commit 11b6747

[TUTORIAL] Remove invalid config in attention (#6889)
This solves the sporadic failures caused by autotune configs where BLOCK_M exceeds N_CTX.
1 parent: dca70ac

File tree

1 file changed: +9 −1 lines

python/tutorials/06-fused-attention.py

Lines changed: 9 additions & 1 deletion
@@ -125,6 +125,13 @@ def keep(conf):
     return not (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)


+def prune_invalid_configs(configs, named_args, **kwargs):
+    N_CTX = kwargs["N_CTX"]
+
+    # Filter out configs where BLOCK_M > N_CTX
+    return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]
+
+
 @triton.jit
 def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
     if isinstance(desc_or_ptr, tl.tensor_descriptor):
@@ -133,7 +140,8 @@ def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
         return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)


-@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"])
+@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
+                 prune_configs_by={'early_config_prune': prune_invalid_configs})
 @triton.jit
 def _attn_fwd(sm_scale, M, #
               Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, #
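
For context, the fix relies on triton.autotune's prune_configs_by hook: the early_config_prune callable receives the candidate configs together with the kernel's runtime arguments and returns the subset that should actually be benchmarked, which is how the change above drops configs whose BLOCK_M exceeds N_CTX before they are ever compiled or launched. The snippet below is a minimal, self-contained sketch of the same mechanism on a hypothetical element-wise kernel; scale_kernel, prune_by_block, the BLOCK sizes, and the n_elements argument are illustrative and not part of the tutorial, and it assumes a recent Triton where the hook is called as hook(configs, named_args, **kwargs).

# Minimal sketch of early_config_prune, assuming a recent Triton where the hook
# is called as hook(configs, named_args, **kwargs). Names below are hypothetical.
import torch
import triton
import triton.language as tl

configs = [triton.Config({"BLOCK": b}, num_warps=4) for b in (64, 128, 256)]


def prune_by_block(configs, named_args, **kwargs):
    # Runtime arguments show up in named_args (positional) or kwargs (keyword),
    # so merge both before looking up the problem size.
    args = {**named_args, **kwargs}
    n_elements = args["n_elements"]
    # Keep only configs whose BLOCK fits the problem; fall back to the first
    # config so the autotuner never sees an empty candidate list.
    pruned = [c for c in configs if c.kwargs["BLOCK"] <= n_elements]
    return pruned or configs[:1]


@triton.autotune(configs=configs, key=["n_elements"],
                 prune_configs_by={"early_config_prune": prune_by_block})
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, factor, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * factor, mask=mask)


if __name__ == "__main__":
    x = torch.randn(100, device="cuda")
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
    # With n_elements == 100, the BLOCK=128 and BLOCK=256 configs are pruned
    # before benchmarking, mirroring how BLOCK_M > N_CTX configs are dropped
    # for the attention kernel.
    scale_kernel[grid](x, out, x.numel(), 2.0)

Pruning before benchmarking means configs that cannot be valid for the given shape are never compiled or launched, which is what removes the sporadic failures this commit addresses.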
