@@ -43,9 +43,30 @@ def _host_descriptor_pre_hook(nargs):
4343 num_warps = 4 ,
4444 pre_hook = _host_descriptor_pre_hook ,
4545 ),
46+ triton .Config (
47+ {
48+ "BLOCK_M" : 256 ,
49+ "BLOCK_N" : 128 ,
50+ "NUM_BUFFERS_Q" : 1 ,
51+ "NUM_BUFFERS_KV" : 6 ,
52+ "NUM_BUFFERS_QK" : 1 ,
53+ "NUM_MMA_GROUPS" : 2 ,
54+ "NUM_MMA_SLICES" : 2 ,
55+ },
56+ num_stages = 0 ,
57+ num_warps = 4 ,
58+ pre_hook = _host_descriptor_pre_hook ,
59+ ),
4660]
4761
4862
def prune_configs_by_hdim(configs, named_args, **kwargs):
    """Early-config-prune hook for ``triton.autotune``.

    Keeps only the autotune configs whose ``NUM_BUFFERS_KV`` matches the
    buffer count appropriate for the kernel's head dimension: 6 KV buffers
    when ``HEAD_DIM == 64``, 3 otherwise.  ``named_args`` is accepted (and
    ignored) to satisfy the hook signature autotune expects; ``HEAD_DIM``
    arrives through ``kwargs``.
    """
    head_dim = kwargs["HEAD_DIM"]
    wanted = {64: 6}.get(head_dim, 3)

    def _matches(cfg):
        # A config qualifies iff its NUM_BUFFERS_KV equals the target count;
        # configs that never set NUM_BUFFERS_KV are treated as 0 and pruned.
        return cfg.kwargs.get("NUM_BUFFERS_KV", 0) == wanted

    return list(filter(_matches, configs))
69+
4970@triton .jit
5071def _get_bufidx_phase (accum_cnt , NUM_BUFFERS_KV ):
5172 bufIdx = accum_cnt % NUM_BUFFERS_KV
@@ -161,15 +182,15 @@ def _mask_scalar(qk, col_limit_right, s, i):
161182
162183
@triton.jit
def _apply_causal_mask(qk, col_limit_right, BLOCK_N: tl.constexpr):
    # Apply causal mask via a bitmask calculated for each block of 16 elements.
    # This allows the efficient R2P (register to predicate) instruction to be used at the SASS level.
    # Credit to Tri Dao,
    # https://github.com/Dao-AILab/flash-attention/commit/bac1001e4f6caa09d70537495d6746a685a2fa78
    #
    # NOTE: We use map_elementwise here in order to generate an interleaved sequence of instructions
    # that processes one element of qk at a time. This improves ptxas's resulting SASS.
    offs_n = tl.arange(0, BLOCK_N)[None, :]
    # Decompose each column index into its 16-aligned group base (clear the
    # low 4 bits) and its offset within that group (keep the low 4 bits);
    # _mask_scalar combines the two with col_limit_right to build the bitmask.
    s = offs_n & ~0xF
    i = offs_n & 0xF
    return tl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i)
@@ -209,16 +230,16 @@ def _softmax_inner_loop(
209230
210231 if STAGE == 2 :
211232 col_limit_right = (offs_m - start_n + 1 )[:, None ]
212- qk = _apply_causal_mask (qk , col_limit_right , HEAD_DIM )
233+ qk = _apply_causal_mask (qk , col_limit_right , BLOCK_N )
213234
214235 # compute m_i, p in registers
215236 m_ij = tl .maximum (m_i , tl .max (qk , 1 ) * qk_scale )
216237
217238 # -- compute correction factor
218239 alpha = tl .math .exp2 (m_i - m_ij )
219240 tlx .barrier_wait (tlx .local_view (alpha_empties , cid ), qk_phase ^ 1 )
220- # Use alpha[0] for cid=0, and alpha[HEAD_DIM ] for cid=1
221- tlx .local_store (tlx .local_view (alpha_tiles , cid * HEAD_DIM ), alpha [:, None ])
241+ # Use alpha[0] for cid=0, and alpha[BLOCK_N ] for cid=1
242+ tlx .local_store (tlx .local_view (alpha_tiles , cid * BLOCK_N ), alpha [:, None ])
222243 tlx .barrier_arrive (tlx .local_view (alpha_fulls , cid ))
223244
224245 qk = _fma_f32x2 (qk , qk_scale , - m_ij [:, None ])
@@ -243,7 +264,11 @@ def _softmax_inner_loop(
243264 return m_i , l_i , accum_cnt_qk
244265
245266
246- @triton .autotune (configs = configs , key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" , "STAGE" ])
267+ @triton .autotune (
268+ configs = configs ,
269+ key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" , "STAGE" ],
270+ prune_configs_by = {"early_config_prune" : prune_configs_by_hdim },
271+ )
247272@triton .jit
248273def _attn_fwd_ws (sm_scale , M , #
249274 Z , H , desc_q , desc_k , desc_v , desc_o , N_CTX , #
@@ -258,7 +283,6 @@ def _attn_fwd_ws(sm_scale, M, #
258283 NUM_MMA_GROUPS : tl .constexpr , #
259284 NUM_MMA_SLICES : tl .constexpr , #
260285 ):
261- tl .static_assert (BLOCK_N <= HEAD_DIM )
262286 tl .static_assert (NUM_MMA_GROUPS == 2 )
263287 tl .static_assert (NUM_BUFFERS_QK == 1 )
264288 tl .static_assert (NUM_BUFFERS_Q == 1 )
@@ -357,8 +381,8 @@ def _attn_fwd_ws(sm_scale, M, #
357381 for cid in tl .static_range (0 , NUM_MMA_GROUPS ):
358382 # -- update output accumulator --
359383 tlx .barrier_wait (alpha_fulls [cid ], phase )
360- # Use alpha[0] for cid=0, and alpha[HEAD_DIM ] for cid=1
361- alpha_1 = tlx .local_load (alpha_tiles [cid * HEAD_DIM ])
384+ # Use alpha[0] for cid=0, and alpha[BLOCK_N ] for cid=1
385+ alpha_1 = tlx .local_load (alpha_tiles [cid * BLOCK_N ])
362386 tlx .barrier_arrive (alpha_empties [cid ])
363387 for slice_id in tl .static_range (0 , NUM_MMA_SLICES ):
364388 subslice = tlx .subslice (
@@ -377,11 +401,11 @@ def _attn_fwd_ws(sm_scale, M, #
377401 for cid in tl .static_range (0 , NUM_MMA_GROUPS ):
378402 # epilogue
379403 tlx .barrier_wait (l_fulls [cid ], phase )
380- # Use l[1]/l[1+HEAD_DIM ] and m[2][2 + HEAD_DIM ]
381- # to disambigulate from alpha[0]/alpha[HEAD_DIM ]
382- l = tlx .local_load (l_tiles [cid * HEAD_DIM + 1 ])
404+ # Use l[1]/l[1+BLOCK_N ] and m[2][2 + BLOCK_N ]
405+ # to disambigulate from alpha[0]/alpha[BLOCK_N ]
406+ l = tlx .local_load (l_tiles [cid * BLOCK_N + 1 ])
383407 tlx .barrier_arrive (qk_empties [cid ])
384- m = tlx .local_load (m_tiles [cid * HEAD_DIM + 2 ])
408+ m = tlx .local_load (m_tiles [cid * BLOCK_N + 2 ])
385409 m += tl .math .log2 (l )
386410 offs_m = (start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl .arange (0 , BLOCK_M_SPLIT ))
387411 m_ptrs = M + off_hz * N_CTX + offs_m
@@ -479,10 +503,10 @@ def _attn_fwd_ws(sm_scale, M, #
479503 )
480504
481505 # prepare l_i for the epilog
482- # Use l[1]/l[1+HEAD_DIM ] and m[2][2 + HEAD_DIM ]
483- # to disambigulate from alpha[0]/alpha[HEAD_DIM ]
484- tlx .local_store (l_tiles [cid * HEAD_DIM + 1 ], l_i [:, None ])
485- tlx .local_store (m_tiles [cid * HEAD_DIM + 2 ], m_i [:, None ])
506+ # Use l[1]/l[1+BLOCK_N ] and m[2][2 + BLOCK_N ]
507+ # to disambigulate from alpha[0]/alpha[BLOCK_N ]
508+ tlx .local_store (l_tiles [cid * BLOCK_N + 1 ], l_i [:, None ])
509+ tlx .local_store (m_tiles [cid * BLOCK_N + 2 ], m_i [:, None ])
486510 tlx .barrier_arrive (l_fulls [cid ])
487511 tile_idx += num_progs
488512
@@ -1621,7 +1645,7 @@ def grid(meta):
16211645@pytest .mark .parametrize ("Z" , [8 ])
16221646@pytest .mark .parametrize ("H" , [16 ])
16231647@pytest .mark .parametrize ("N_CTX" , [1024 ])
1624- @pytest .mark .parametrize ("HEAD_DIM" , [128 ])
1648+ @pytest .mark .parametrize ("HEAD_DIM" , [64 , 128 ])
16251649@pytest .mark .parametrize ("mode" , ["fwd" , "bwd" ])
16261650@pytest .mark .parametrize ("provider" , ["triton-fp16" ])
16271651@pytest .mark .parametrize ("causal" , [True , False ])
@@ -1633,7 +1657,9 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, causal, dtype=torch.float16):
16331657 sm_scale = 0.5
16341658 # reference implementation
16351659 ref_dtype = dtype
1636- if mode == "fwd" and not causal :
1660+ if mode == "bwd" and HEAD_DIM == 64 :
1661+ pytest .skip ("Only test bwd with 128" )
1662+ elif mode == "fwd" and not causal and HEAD_DIM == 128 :
16371663 pytest .skip ("Only test fwd with causal" )
16381664 elif mode == "bwd" and causal :
16391665 pytest .skip ("Causal not supported for bwd yet" )
0 commit comments