Skip to content

Commit d7a9234

Browse files
authored
perf: prefetch page indices for mla kernel (flashinfer-ai#991)
Followup of flashinfer-ai#952 cc @abcdabcd987 ## Before this PR ``` Config: batch_size=64, seq_len=1024, num_heads=64 Memory bandwidth: 1509.87 GB/s FLOPs: 163.25 TFLOPs Config: batch_size=64, seq_len=1024, num_heads=128 Memory bandwidth: 1766.19 GB/s FLOPs: 345.46 TFLOPs Config: batch_size=128, seq_len=1024, num_heads=64 Memory bandwidth: 2307.97 GB/s FLOPs: 249.55 TFLOPs Config: batch_size=128, seq_len=1024, num_heads=128 Memory bandwidth: 1975.24 GB/s FLOPs: 386.35 TFLOPs Config: batch_size=768, seq_len=1024, num_heads=64 Memory bandwidth: 2871.63 GB/s FLOPs: 310.49 TFLOPs Config: batch_size=768, seq_len=1024, num_heads=128 Memory bandwidth: 2225.07 GB/s FLOPs: 435.21 TFLOPs Config: batch_size=64, seq_len=2048, num_heads=64 Memory bandwidth: 1948.15 GB/s FLOPs: 222.38 TFLOPs Config: batch_size=64, seq_len=2048, num_heads=128 Memory bandwidth: 1973.36 GB/s FLOPs: 426.74 TFLOPs Config: batch_size=128, seq_len=2048, num_heads=64 Memory bandwidth: 2625.63 GB/s FLOPs: 299.72 TFLOPs Config: batch_size=128, seq_len=2048, num_heads=128 Memory bandwidth: 2121.92 GB/s FLOPs: 458.86 TFLOPs Config: batch_size=768, seq_len=2048, num_heads=64 Memory bandwidth: 2996.11 GB/s FLOPs: 342.01 TFLOPs Config: batch_size=768, seq_len=2048, num_heads=128 Memory bandwidth: 2146.40 GB/s FLOPs: 464.16 TFLOPs Config: batch_size=64, seq_len=8192, num_heads=64 Memory bandwidth: 2717.28 GB/s FLOPs: 323.71 TFLOPs Config: batch_size=64, seq_len=8192, num_heads=128 Memory bandwidth: 2129.24 GB/s FLOPs: 500.04 TFLOPs Config: batch_size=128, seq_len=8192, num_heads=64 Memory bandwidth: 3002.75 GB/s FLOPs: 357.72 TFLOPs Config: batch_size=128, seq_len=8192, num_heads=128 Memory bandwidth: 2101.93 GB/s FLOPs: 493.63 TFLOPs Config: batch_size=768, seq_len=8192, num_heads=64 Memory bandwidth: 3083.42 GB/s FLOPs: 367.33 TFLOPs Config: batch_size=768, seq_len=8192, num_heads=128 Memory bandwidth: 2064.96 GB/s FLOPs: 484.95 TFLOPs ``` ## After this PR ``` Config: batch_size=64, seq_len=1024, num_heads=64 Memory bandwidth: 1596.98 GB/s FLOPs: 172.67 TFLOPs Config: batch_size=64, seq_len=1024, num_heads=128 Memory bandwidth: 1685.22 GB/s FLOPs: 329.62 TFLOPs Config: batch_size=128, seq_len=1024, num_heads=64 Memory bandwidth: 2280.49 GB/s FLOPs: 246.58 TFLOPs Config: batch_size=128, seq_len=1024, num_heads=128 Memory bandwidth: 1917.53 GB/s FLOPs: 375.06 TFLOPs Config: batch_size=768, seq_len=1024, num_heads=64 Memory bandwidth: 2869.03 GB/s FLOPs: 310.21 TFLOPs Config: batch_size=768, seq_len=1024, num_heads=128 Memory bandwidth: 2208.35 GB/s FLOPs: 431.94 TFLOPs Config: batch_size=64, seq_len=2048, num_heads=64 Memory bandwidth: 2047.44 GB/s FLOPs: 233.72 TFLOPs Config: batch_size=64, seq_len=2048, num_heads=128 Memory bandwidth: 1936.08 GB/s FLOPs: 418.67 TFLOPs Config: batch_size=128, seq_len=2048, num_heads=64 Memory bandwidth: 2617.48 GB/s FLOPs: 298.79 TFLOPs Config: batch_size=128, seq_len=2048, num_heads=128 Memory bandwidth: 2105.97 GB/s FLOPs: 455.41 TFLOPs Config: batch_size=768, seq_len=2048, num_heads=64 Memory bandwidth: 2999.55 GB/s FLOPs: 342.40 TFLOPs Config: batch_size=768, seq_len=2048, num_heads=128 Memory bandwidth: 2181.54 GB/s FLOPs: 471.75 TFLOPs Config: batch_size=64, seq_len=8192, num_heads=64 Memory bandwidth: 2780.86 GB/s FLOPs: 331.29 TFLOPs Config: batch_size=64, seq_len=8192, num_heads=128 Memory bandwidth: 2176.12 GB/s FLOPs: 511.05 TFLOPs Config: batch_size=128, seq_len=8192, num_heads=64 Memory bandwidth: 3031.58 GB/s FLOPs: 361.15 TFLOPs Config: batch_size=128, seq_len=8192, num_heads=128 Memory bandwidth: 2165.73 GB/s FLOPs: 508.61 TFLOPs Config: batch_size=768, seq_len=8192, num_heads=64 Memory bandwidth: 3126.37 GB/s FLOPs: 372.45 TFLOPs Config: batch_size=768, seq_len=8192, num_heads=128 Memory bandwidth: 2142.42 GB/s FLOPs: 503.14 TFLOPs ```
1 parent 17ff5a7 commit d7a9234

File tree

2 files changed

+57
-24
lines changed

2 files changed

+57
-24
lines changed

include/flashinfer/attention/mla.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ __device__ void DevicePersistentMergeStates(
629629
typename KTraits::IdType* merge_partial_stride, typename KTraits::DTypeO* partial_o,
630630
float* partial_lse, typename KTraits::DTypeO* final_o, float* final_lse,
631631
const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv& num_heads) {
632-
constexpr uint32_t VEC_SIZE = 4; // partial o has data type float
632+
constexpr uint32_t VEC_SIZE = 8; // partial o has data type float
633633
constexpr uint32_t NUM_THRS_PER_ROW = KTraits::HEAD_DIM_CKV / VEC_SIZE;
634634
constexpr uint32_t ROWS_PER_ITERATION = (KTraits::NUM_THREADS) / NUM_THRS_PER_ROW;
635635
const uint32_t cta_idx = (gridDim.x * blockIdx.y + blockIdx.x);

include/flashinfer/attention/mla_hopper.cuh

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,47 @@ __device__ __forceinline__ void load_q(
170170
}
171171
}
172172

173+
template <typename KTraits>
174+
__device__ __forceinline__ void prefetch_offset(
175+
const uint32_t packed_block_iter_base, const uint32_t packed_kv_bound,
176+
const uint32_t ckv_stride_page, const uint32_t ckv_stride_n, const uint32_t kpe_stride_page,
177+
const uint32_t kpe_stride_n, const uint_fastdiv& block_size, typename KTraits::IdType* indices,
178+
int64_t (*ckv_offset)[2], int64_t (*kpe_offset)[2]) {
179+
using DTypeKV = typename KTraits::DTypeKV;
180+
const uint32_t lane_idx = cutlass::canonical_lane_idx();
181+
const uint32_t warp_idx_in_wg = cutlass::canonical_warp_idx() % 4;
182+
#pragma unroll
183+
for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
184+
#pragma unroll
185+
for (uint32_t j = 0; j < 2; ++j) {
186+
uint32_t q, r;
187+
uint32_t packed_block_iter =
188+
packed_block_iter_base + lane_idx / 8 + (j + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
189+
block_size.divmod(packed_block_iter, q, r);
190+
ckv_offset[mma_kv][j] =
191+
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
192+
r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
193+
kpe_offset[mma_kv][j] =
194+
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
195+
r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
196+
}
197+
}
198+
}
199+
173200
template <bool predicate, typename KTraits>
174-
__device__ __forceinline__ void load_kv(
175-
typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeKV* ckv,
176-
typename KTraits::DTypeKV* kpe, typename KTraits::IdType* indices, const uint32_t ckv_stride_n,
177-
const uint32_t ckv_stride_page, const uint32_t kpe_stride_n, const uint32_t kpe_stride_page,
178-
const uint32_t packed_kv_bound, const uint32_t packed_block_iter_base,
179-
const uint_fastdiv& block_size, const uint32_t stage_idx) {
201+
__device__ __forceinline__ void load_kv(typename KTraits::SharedStorage* smem_storage,
202+
typename KTraits::DTypeKV* ckv,
203+
typename KTraits::DTypeKV* kpe,
204+
const uint32_t packed_kv_bound,
205+
const uint32_t packed_block_iter_base,
206+
const uint32_t stage_idx, int64_t (*ckv_offset)[2],
207+
int64_t (*kpe_offset)[2]) {
180208
using DTypeKV = typename KTraits::DTypeKV;
181209
constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
182210
constexpr uint32_t UPCAST_STRIDE_KPE = KTraits::UPCAST_STRIDE_KPE;
183211
constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
184212
constexpr uint32_t NUM_MMA_D_KPE = KTraits::NUM_MMA_D_KPE;
185213
const uint32_t lane_idx = cutlass::canonical_lane_idx();
186-
const uint32_t warp_group_idx = cutlass::canonical_warp_group_idx();
187214
const uint32_t warp_idx_in_wg = cutlass::canonical_warp_idx() % 4;
188215

189216
smem_t<KTraits::SWIZZLE_MODE_CKV> ckv_smem(smem_storage->kv_o_smem[stage_idx].ckv);
@@ -193,17 +220,11 @@ __device__ __forceinline__ void load_kv(
193220
for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
194221
#pragma unroll
195222
for (uint32_t j = 0; j < 2; ++j) {
196-
uint32_t q, r;
197223
uint32_t packed_block_iter =
198224
packed_block_iter_base + lane_idx / 8 + (j + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
199-
block_size.divmod(packed_block_iter, q, r);
200225

201-
DTypeKV* ckv_ptr = ckv +
202-
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
203-
r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
204-
DTypeKV* kpe_ptr = kpe +
205-
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
206-
r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
226+
DTypeKV* ckv_ptr = ckv + ckv_offset[mma_kv][j];
227+
DTypeKV* kpe_ptr = kpe + kpe_offset[mma_kv][j];
207228
uint32_t ckv_smem_offset_w = get_swizzle_offset<KTraits::SWIZZLE_MODE_CKV, UPCAST_STRIDE_CKV>(
208229
32 * mma_kv + j * 16 + warp_idx_in_wg * 4 + lane_idx / 8, 8 * 0 + lane_idx % 8);
209230
uint32_t kpe_smem_offset_w = get_swizzle_offset<KTraits::SWIZZLE_MODE_KPE, UPCAST_STRIDE_KPE>(
@@ -657,6 +678,9 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
657678
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
658679
PipelineState smem_pipe_read_kv;
659680

681+
int64_t ckv_offset[KTraits::NUM_MMA_KV / 2][2];
682+
int64_t kpe_offset[KTraits::NUM_MMA_KV / 2][2];
683+
660684
#pragma unroll 1
661685
for (IdType work_idx = work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1];
662686
++work_idx) {
@@ -681,15 +705,20 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
681705

682706
const uint32_t block_iter_base = kv_indptr * block_size + kv_start;
683707

708+
prefetch_offset<KTraits>(block_iter_base + kv_tile_idx * CTA_TILE_KV, packed_kv_bound,
709+
ckv_stride_page, ckv_stride_n, kpe_stride_page, kpe_stride_n,
710+
block_size, kv_indices, ckv_offset, kpe_offset);
684711
if (has_kv) {
685712
pipeline_kv.producer_acquire(smem_pipe_write_kv);
686-
load_kv<true, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
687-
kpe_stride_n, kpe_stride_page, packed_kv_bound,
688-
block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
689-
smem_pipe_write_kv.index());
713+
load_kv<true, KTraits>(&smem_storage, ckv, kpe, packed_kv_bound,
714+
block_iter_base + kv_tile_idx * CTA_TILE_KV,
715+
smem_pipe_write_kv.index(), ckv_offset, kpe_offset);
690716
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
691717
kv_tile_idx -= 1;
692718
++smem_pipe_write_kv;
719+
prefetch_offset<KTraits>(block_iter_base + kv_tile_idx * CTA_TILE_KV, packed_kv_bound,
720+
ckv_stride_page, ckv_stride_n, kpe_stride_page, kpe_stride_n,
721+
block_size, kv_indices, ckv_offset, kpe_offset);
693722
}
694723

695724
pipeline_q.producer_acquire(smem_pipe_write_q);
@@ -703,10 +732,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
703732
#pragma unroll 1
704733
for (; kv_tile_idx >= 0; --kv_tile_idx) {
705734
pipeline_kv.producer_acquire(smem_pipe_write_kv);
706-
load_kv<false, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
707-
kpe_stride_n, kpe_stride_page, packed_kv_bound,
708-
block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
709-
smem_pipe_write_kv.index());
735+
load_kv<false, KTraits>(&smem_storage, ckv, kpe, packed_kv_bound,
736+
block_iter_base + kv_tile_idx * CTA_TILE_KV,
737+
smem_pipe_write_kv.index(), ckv_offset, kpe_offset);
738+
if (kv_tile_idx > 0) {
739+
prefetch_offset<KTraits>(block_iter_base + (kv_tile_idx - 1) * CTA_TILE_KV,
740+
packed_kv_bound, ckv_stride_page, ckv_stride_n, kpe_stride_page,
741+
kpe_stride_n, block_size, kv_indices, ckv_offset, kpe_offset);
742+
}
710743
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
711744
++smem_pipe_write_kv;
712745

0 commit comments

Comments
 (0)