@@ -525,7 +525,6 @@ def _attention_prefill(h_kv, h_q, d, dtype, sliding_window: bool, target: Target
     bdx = 32
     num_warps = 4
     tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
-    L_per_cta = tile_x // group_size
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
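
For reference, the `tile_x` expression above resolves to concrete tile sizes as follows; this is an illustrative recomputation of the same arithmetic, not part of the patch:

    # Illustrative only: reproduces the tile_x arithmetic from the hunk above.
    def tile_x_for(dtype_bits: int, d: int) -> int:
        bytes_per_elem = (dtype_bits + 7) // 8
        return 64 // bytes_per_elem // max(d // 128, 1)

    assert tile_x_for(16, 128) == 32  # fp16, head dim 128
    assert tile_x_for(16, 256) == 16  # fp16, head dim 256
    assert tile_x_for(32, 128) == 16  # fp32, head dim 128

With fp16 and d = 128 each CTA tile is therefore 32 rows by `tile_y = d` columns, with KV processed in chunks of `tile_z = 16`.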
@@ -641,8 +640,7 @@ def batch_prefill_paged_kv(
 
                 if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                     b_idx: T.int32 = batch_idx[0]
-                    L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
-                    H_qo_start: T.int32 = by * group_size
+                    LH_start: T.int32 = tile_id[0] * tile_x
 
                     cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                     cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
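
The dropped `L_per_cta = tile_x // group_size` tied each CTA to a whole number of query positions, which truncates when `group_size` does not divide `tile_x`; the new `LH_start` instead offsets into a flattened (position, head-in-group) space, so every row of the tile is usable. A quick arithmetic check of that truncation (illustrative; the divisibility issue is inferred from the change, not stated in the diff):

    # Illustrative: how many (position, head) pairs the old per-CTA tiling reaches.
    tile_x = 32
    for group_size in (1, 4, 6, 7):
        L_per_cta = tile_x // group_size      # old scheme, truncating division
        covered = L_per_cta * group_size      # rows of the tile actually used
        print(f"group_size={group_size}: L_per_cta={L_per_cta}, covers {covered}/{tile_x} rows")
    # group_size=6 covers only 30 of 32 rows; group_size=7 covers 28 of 32.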
@@ -672,8 +670,8 @@ def batch_prefill_paged_kv(
                             i, j = T.axis.remap("SS", [li, lj])
                             T.reads()
                             T.writes()
-                            cur_L = L_start + i // group_size
-                            cur_H_qo = H_qo_start + i % group_size
+                            cur_L = q_indptr[b_idx] + (LH_start + i) // group_size
+                            cur_H_qo = by * group_size + (LH_start + i) % group_size
                             if cur_L < q_indptr[b_idx + 1]:
                                 Q_smem[i, j] = T.if_then_else(
                                     rotary_mode == 1,
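
In the new indexing, `LH_start + i` is a position in a per-request flattened (query position, head-in-group) space: the quotient by `group_size` gives the row offset into the packed Q tensor and the remainder gives the query head within the KV group owned by `blockIdx.y`. A minimal sketch of that mapping (function and variable names here are illustrative, not from the kernel):

    # Sketch of the cur_L / cur_H_qo computation in the Q_load block above.
    def tile_coords(tile_id, tile_x, group_size, q_start, by):
        LH_start = tile_id * tile_x
        for i in range(tile_x):
            cur_L = q_start + (LH_start + i) // group_size            # row in the packed Q tensor
            cur_H_qo = by * group_size + (LH_start + i) % group_size  # query-head index
            yield cur_L, cur_H_qo

    # Second tile of a request starting at q_indptr[b_idx] == 10,
    # KV head by == 2, group_size == 6, tile_x == 32:
    pairs = list(tile_coords(tile_id=1, tile_x=32, group_size=6, q_start=10, by=2))
    assert pairs[0] == (15, 14)   # 10 + 32 // 6, 2 * 6 + 32 % 6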
@@ -742,9 +740,10 @@ def batch_prefill_paged_kv(
                             m_prev[i] = m_smem[row]
                             m_new[i] = m_smem[row]
                             # mask out of kv_chunk_len S
+                            row_: T.int32 = (LH_start + row) // group_size
                             for j in T.serial(tile_z):
                                 if _causal_mask(causal,
-                                        row=tile_id[0] * L_per_cta + row // group_size,
+                                        row=row_,
                                         col=L_kv_start + j,
                                         kv_len=kv_chunk_len[0],
                                         qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
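
`row_` is the query position (within the current request) of shared-memory row `row`, hoisted out of the inner loop so the mask argument no longer depends on the removed `L_per_cta`. The exact predicate is defined by `_causal_mask` elsewhere in the file; a plain-Python stand-in with the usual causal condition is sketched below, under the assumption that it matches:

    # Assumed stand-in for _causal_mask, for illustration only.
    def causal_mask(causal: bool, row: int, col: int, kv_len: int, qo_len: int) -> bool:
        if not causal:
            return col < kv_len                  # only clip columns past the KV chunk
        # Align the last query position of the request with the last KV position.
        return col < kv_len - qo_len + row + 1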
@@ -757,8 +756,9 @@ def batch_prefill_paged_kv(
                             for j in T.serial(tile_z):
                                 # this is to avoid sync inside condition branch
                                 if row < tile_x:
+                                    row_: T.int32 = (LH_start + row) // group_size
                                     if _causal_mask(causal,
-                                            row=tile_id[0] * L_per_cta + row // group_size,
+                                            row=row_,
                                             col=L_kv_start + j,
                                             kv_len=kv_chunk_len[0],
                                             qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -790,15 +790,19 @@ def batch_prefill_paged_kv(
                     for li, lj in T.grid(tile_x, tile_y):
                         with T.block("O_store"):
                             i, j = T.axis.remap("SS", [li, lj])
-                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                            if cur_L < q_indptr[b_idx + 1]:
+                                output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]
 
                     # Store LSE to gmem
                     for li in T.grid(tile_x):
                         with T.block("lse_store"):
                             i = T.axis.remap("S", [li])
-                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                            if cur_L < q_indptr[b_idx + 1]:
+                                lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
 
                     # move to next tile
                     tile_id[0] += NUM_BLKS
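
The `O_store` and `lse_store` blocks reuse the same quotient/remainder mapping to scatter each tile row back to `(cur_L, cur_H_qo)`, writing the normalized output `O_local / d` and the base-2 log-sum-exp `m + log2(d)`, where `m_smem` and `d_smem` hold the running maximum and running denominator maintained by the online-softmax update. A small numerical check of that log-sum-exp identity (illustrative, assuming NumPy is available):

    import numpy as np

    # m + log2(d), with m = max(s) and d = sum(2**(s - m)), equals log2(sum(2**s)).
    s = np.array([1.5, -0.25, 3.0, 2.0])   # hypothetical base-2-scaled attention scores
    m = s.max()
    d = np.exp2(s - m).sum()
    assert np.isclose(m + np.log2(d), np.log2(np.exp2(s).sum()))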
@@ -1218,7 +1222,6 @@ def _attention_prefill_ragged(
     bdx = 32
     num_warps = 4
     tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
-    L_per_cta = tile_x // group_size
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -1313,8 +1316,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
 
                 if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                     b_idx: T.int32 = batch_idx[0]
-                    L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
-                    H_qo_start: T.int32 = by * group_size
+                    LH_start: T.int32 = tile_id[0] * tile_x
 
                     kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                     T.tvm_storage_sync("shared")
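
Apart from this, the ragged-KV kernel mirrors the paged one; the KV extent of a request comes directly from `kv_indptr` rather than from a page table. For orientation, an indptr array of this shape is an exclusive prefix sum over per-request lengths (illustrative sketch; the lengths below are hypothetical):

    # Illustrative: building a ragged indptr array from per-request KV lengths.
    kv_lens = [7, 0, 13, 4]                     # hypothetical per-request KV lengths
    kv_indptr = [0]
    for n in kv_lens:
        kv_indptr.append(kv_indptr[-1] + n)     # -> [0, 7, 7, 20, 24]

    b_idx = 2
    kv_chunk_len = kv_indptr[b_idx + 1] - kv_indptr[b_idx]   # 13, as computed in the kernel
    assert kv_chunk_len == kv_lens[b_idx]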
@@ -1338,8 +1340,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
                             i, j = T.axis.remap("SS", [li, lj])
                             T.reads()
                             T.writes()
-                            cur_L = L_start + i // group_size
-                            cur_H_qo = H_qo_start + i % group_size
+                            cur_L = q_indptr[b_idx] + (LH_start + i) // group_size
+                            cur_H_qo = by * group_size + (LH_start + i) % group_size
                             if cur_L < q_indptr[b_idx + 1]:
                                 Q_smem[i, j] = T.if_then_else(
                                     rotary_mode == 1,
@@ -1403,9 +1405,10 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
                             m_prev[i] = m_smem[row]
                             m_new[i] = m_smem[row]
                             # mask out of kv_chunk_len S
+                            row_: T.int32 = (LH_start + row) // group_size
                             for j in T.serial(tile_z):
                                 if _causal_mask(causal,
-                                        row=tile_id[0] * L_per_cta + row // group_size,
+                                        row=row_,
                                         col=L_kv_start + j,
                                         kv_len=kv_chunk_len[0],
                                         qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1418,8 +1421,9 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
                             for j in T.serial(tile_z):
                                 # this is to avoid sync inside condition branch
                                 if row < tile_x:
+                                    row_: T.int32 = (LH_start + row) // group_size
                                     if _causal_mask(causal,
-                                            row=tile_id[0] * L_per_cta + row // group_size,
+                                            row=row_,
                                             col=L_kv_start + j,
                                             kv_len=kv_chunk_len[0],
                                             qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1451,15 +1455,19 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
                     for li, lj in T.grid(tile_x, tile_y):
                         with T.block("O_store"):
                             i, j = T.axis.remap("SS", [li, lj])
-                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                            if cur_L < q_indptr[b_idx + 1]:
+                                output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]
 
                     # Store LSE to gmem
                     for li in T.grid(tile_x):
                         with T.block("lse_store"):
                             i = T.axis.remap("S", [li])
-                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                            if cur_L < q_indptr[b_idx + 1]:
+                                lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
 
                     # move to next tile
                     tile_id[0] += NUM_BLKS
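
As in the paged kernel, `tile_id[0] += NUM_BLKS` advances a persistent-block (grid-stride) schedule: each CTA keeps claiming every NUM_BLKS-th tile across all requests until none remain. A host-side sketch of the same iteration order (illustrative; the tile count and NUM_BLKS value are hypothetical):

    # Illustrative grid-stride schedule corresponding to tile_id[0] += NUM_BLKS.
    NUM_BLKS = 16                      # hypothetical CTA count
    total_tiles = 53                   # hypothetical number of tiles over all requests
    for blk in range(NUM_BLKS):        # conceptually one persistent CTA per iteration
        tile_id = blk
        while tile_id < total_tiles:
            # ... process tile `tile_id` for its request ...
            tile_id += NUM_BLKS        # move to next tile, as in the kernel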