Skip to content

Commit 47b9339

Browse files
[DeepSeek] Improve performance of DS MLA cache kernel (vllm-project#26132)
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent 5d5146e commit 47b9339

File tree

1 file changed

+62
-68
lines changed

1 file changed

+62
-68
lines changed

csrc/cache_kernels.cu

Lines changed: 62 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#include <algorithm>
1818
#include <cassert>
19-
#include <cfloat> // FLT_MIN
2019
#include <map>
2120
#include <vector>
2221

@@ -424,84 +423,80 @@ __global__ void concat_and_cache_ds_mla_kernel(
424423
const int64_t dst_idx_start =
425424
block_idx * block_stride + block_offset * entry_stride;
426425

427-
// Create 4 tile scales in shared memory
428-
__shared__ float smem[20];
429-
float* shard_abs_max = smem;
430-
float* tile_scales = smem + 16;
431-
432-
// For the NoPE part, each tile of 128 elements is handled by 4 warps
433-
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
434-
// The first thread of the first warp in each tile writes the scale
435-
// value for the tile. The RoPE part (last 64 elements) is handled
436-
// by another 2 warps (64 threads).
437-
// So in total, we use 18 warps (576 threads) per block.
426+
// For the NoPE part, each tile of 128 elements is handled by half of one warp
427+
// (16 threads). There are 4 total tiles, so 2 warps (64 threads).
428+
// Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
429+
// The RoPE part (last 64 elements) is handled by another 1 warp (32 threads).
430+
// So in total, we use 3 warps (96 threads) per block.
438431

439432
// Cast kv_cache to 16_bit for RoPE values
440433
scalar_t* kv_cache_16bit =
441434
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
442435

443-
// The last 64 threads handle the RoPE part
444-
if (threadIdx.x >= kv_lora_rank) {
445-
const int8_t pe_idx = threadIdx.x - kv_lora_rank;
446-
const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
436+
// The last warp handles the RoPE part
437+
if (threadIdx.x >= 64) {
438+
// Each thread handles two elements of RoPE
439+
const int8_t pe_idx_start = (threadIdx.x - 64) * 2;
440+
const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start;
441+
// Vectorized load of two 16-bit values, performed as one 32-bit load
442+
const int32_t vals = *reinterpret_cast<const int32_t*>(&k_pe[src_idx]);
447443
// RoPE values start after the packed 8-bit NoPE values and the
448444
// 32-bit scales
449-
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
450-
kv_cache_16bit[dst_idx] = k_pe[src_idx];
445+
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start;
446+
// Vectorized store of two 16-bit values, performed as one 32-bit store
447+
*reinterpret_cast<int32_t*>(&kv_cache_16bit[dst_idx]) = vals;
451448
return;
452449
}
453450

454-
// Determine the scale for each chunk of NoPE
455-
const int16_t tile_idx = threadIdx.x >> 7;
456-
const int16_t warp_idx = (threadIdx.x & 127) >> 5;
457-
const int16_t lane_idx = threadIdx.x & 31;
458-
459-
// Load the NoPE element for this thread into registers
460-
const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
461-
const scalar_t src_val = kv_c[src_idx];
462-
463-
// Warp-level reduction to find the max absolute value in the warp
464-
float max_abs = fabsf(src_val);
451+
// The first two warps handle the NoPE part
452+
const int8_t warp_idx = threadIdx.x >> 5;
453+
const int8_t lane_idx = threadIdx.x & 31;
454+
const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);
455+
456+
// Each thread handles 8 elements of NoPE
457+
// Load the NoPE elements for this thread into registers
458+
const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8);
459+
// Vectorized load of eight 16-bit values, performed as an int4 load
460+
const int4 vals_i4 = *reinterpret_cast<const int4*>(&kv_c[src_idx_start]);
461+
const scalar_t* vals = reinterpret_cast<const scalar_t*>(&vals_i4);
462+
463+
// Max absolute value of this thread's elements
464+
float max_abs = fmaxf(fmaxf(fmaxf(fabsf(vals[0]), fabsf(vals[1])),
465+
fmaxf(fabsf(vals[2]), fabsf(vals[3]))),
466+
fmaxf(fmaxf(fabsf(vals[4]), fabsf(vals[5])),
467+
fmaxf(fabsf(vals[6]), fabsf(vals[7]))));
468+
469+
// Warp-level reduction to find the max absolute value in each half-warp
465470
#pragma unroll
466-
for (int offset = 16; offset > 0; offset /= 2) {
467-
#ifdef USE_ROCM
468-
max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
469-
#else
470-
max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
471-
#endif
471+
for (int offset = 8; offset > 0; offset /= 2) {
472+
max_abs = fmaxf(max_abs, VLLM_SHFL_XOR_SYNC_WIDTH(max_abs, offset, 16));
472473
}
473474

474-
// The first lane of each warp in each tile writes the max_abs of this part
475-
// of the tile to shared memory
476-
if (lane_idx == 0) {
477-
shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
478-
}
479-
__syncthreads();
480-
481-
// The first lane of the first warp in each tile computes the scale for the
482-
// tile and writes it to shared memory and to kv_cache
483-
if (warp_idx == 0 && lane_idx == 0) {
484-
float4 shard_abs_max_vec =
485-
reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
486-
float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
487-
fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
488-
448.f;
489-
490-
// Avoid division by zero in `scaled_convert`
491-
tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
475+
// Compute the scale for the tile
476+
float tile_scale = max_abs / 448.f;
477+
478+
// The first lane of each half-warp writes the scale to kv_cache
479+
if ((lane_idx == 0) || (lane_idx == 16)) {
492480
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
493481
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
494-
kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
482+
kv_cache_32bit[dst_idx] = tile_scale;
495483
}
496484

497-
__syncthreads();
485+
// Now all threads in the block scale and write their elements
486+
// NoPE data is packed in the first kv_lora_rank/2 bytes (first 256 bytes)
487+
const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8);
488+
489+
uint8_t result[8];
490+
#pragma unroll
491+
for (int i = 0; i < 8; i++) {
492+
result[i] =
493+
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
494+
vals[i], tile_scale);
495+
}
498496

499-
// Now all threads in the block scale and write their element
500-
const float scale_val = tile_scales[tile_idx];
501-
const int64_t dst_idx = dst_idx_start + threadIdx.x;
502-
kv_cache[dst_idx] =
503-
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
504-
src_val, scale_val);
497+
// Store as aligned 64-bit writes
498+
*reinterpret_cast<uint64_t*>(&kv_cache[dst_idx_base]) =
499+
*reinterpret_cast<const uint64_t*>(result);
505500
}
506501

507502
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
@@ -741,13 +736,12 @@ void concat_and_cache_mla(
741736

742737
if (kv_cache_dtype == "fp8_ds_mla") {
743738
dim3 grid(num_tokens);
744-
// For the NoPE part, each tile of 128 elements is handled by 4 warps
745-
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
746-
// The first thread of the first warp in each tile writes the scale
747-
// value for the tile. The RoPE part (last 64 elements) is handled
748-
// by another 2 warps (64 threads).
749-
// So in total, we use 18 warps (576 threads) per block.
750-
dim3 block(576);
739+
// For the NoPE part, each tile of 128 elements is handled by half of one
740+
// warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
741+
// Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
742+
// The RoPE part (last 64 elements) is handled by another 1 warp (32
743+
// threads). So in total, we use 3 warps (96 threads) per block.
744+
dim3 block(96);
751745
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
752746
CALL_CONCAT_AND_CACHE_DS_MLA);
753747
} else {

0 commit comments

Comments
 (0)