@@ -16,7 +16,6 @@
 
 #include <algorithm>
 #include <cassert>
-#include <cfloat>  // FLT_MIN
 #include <map>
 #include <vector>
 
@@ -424,84 +423,80 @@ __global__ void concat_and_cache_ds_mla_kernel(
   const int64_t dst_idx_start =
       block_idx * block_stride + block_offset * entry_stride;
 
-  // Create 4 tile scales in shared memory
-  __shared__ float smem[20];
-  float* shard_abs_max = smem;
-  float* tile_scales = smem + 16;
-
-  // For the NoPE part, each tile of 128 elements is handled by 4 warps
-  // (128 threads). There are 4 total tiles, so 16 warps (512 threads).
-  // The first thread of the first warp in each tile writes the scale
-  // value for the tile. The RoPE part (last 64 elements) is handled
-  // by another 2 warps (64 threads).
-  // So in total, we use 18 warps (576 threads) per block.
+  // For the NoPE part, each tile of 128 elements is handled by half of one warp
+  // (16 threads). There are 4 total tiles, so 2 warps (64 threads).
+  // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
+  // The RoPE part (last 64 elements) is handled by another 1 warp (32 threads).
+  // So in total, we use 3 warps (96 threads) per block.
 
   // Cast kv_cache to 16_bit for RoPE values
   scalar_t* kv_cache_16bit =
       reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
 
-  // The last 64 threads handle the RoPE part
-  if (threadIdx.x >= kv_lora_rank) {
-    const int8_t pe_idx = threadIdx.x - kv_lora_rank;
-    const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
+  // The last warp handles the RoPE part
+  if (threadIdx.x >= 64) {
+    // Each thread handles two elements of RoPE
+    const int8_t pe_idx_start = (threadIdx.x - 64) * 2;
+    const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start;
+    // Vectorized load of two 16-bit values, performed as one 32-bit load
+    const int32_t vals = *reinterpret_cast<const int32_t*>(&k_pe[src_idx]);
     // RoPE values start after the packed 8-bit NoPE values and the
     // 32-bit scales
-    const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
-    kv_cache_16bit[dst_idx] = k_pe[src_idx];
+    const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start;
+    // Vectorized store of two 16-bit values, performed as one 32-bit store
+    *reinterpret_cast<int32_t*>(&kv_cache_16bit[dst_idx]) = vals;
     return;
   }
 
-  // Determine the scale for each chunk of NoPE
-  const int16_t tile_idx = threadIdx.x >> 7;
-  const int16_t warp_idx = (threadIdx.x & 127) >> 5;
-  const int16_t lane_idx = threadIdx.x & 31;
-
-  // Load the NoPE element for this thread into registers
-  const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
-  const scalar_t src_val = kv_c[src_idx];
-
-  // Warp-level reduction to find the max absolute value in the warp
-  float max_abs = fabsf(src_val);
+  // The first two warps handle the NoPE part
+  const int8_t warp_idx = threadIdx.x >> 5;
+  const int8_t lane_idx = threadIdx.x & 31;
+  const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);
+
+  // Each thread handles 8 elements of NoPE
+  // Load the NoPE elements for this thread into registers
+  const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8);
+  // Vectorized load of eight 16-bit values, performed as an int4 load
+  const int4 vals_i4 = *reinterpret_cast<const int4*>(&kv_c[src_idx_start]);
+  const scalar_t* vals = reinterpret_cast<const scalar_t*>(&vals_i4);
+
+  // Max absolute value of this thread's elements
+  float max_abs = fmaxf(fmaxf(fmaxf(fabsf(vals[0]), fabsf(vals[1])),
+                              fmaxf(fabsf(vals[2]), fabsf(vals[3]))),
+                        fmaxf(fmaxf(fabsf(vals[4]), fabsf(vals[5])),
+                              fmaxf(fabsf(vals[6]), fabsf(vals[7]))));
+
+  // Warp-level reduction to find the max absolute value in each half-warp
 #pragma unroll
-  for (int offset = 16; offset > 0; offset /= 2) {
-#ifdef USE_ROCM
-    max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
-#else
-    max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
-#endif
+  for (int offset = 8; offset > 0; offset /= 2) {
+    max_abs = fmaxf(max_abs, VLLM_SHFL_XOR_SYNC_WIDTH(max_abs, offset, 16));
   }
 
-  // The first lane of each warp in each tile writes the max_abs of this part
-  // of the tile to shared memory
-  if (lane_idx == 0) {
-    shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
-  }
-  __syncthreads();
-
-  // The first lane of the first warp in each tile computes the scale for the
-  // tile and writes it to shared memory and to kv_cache
-  if (warp_idx == 0 && lane_idx == 0) {
-    float4 shard_abs_max_vec =
-        reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
-    float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
-                             fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
-                       448.f;
-
-    // Avoid division by zero in `scaled_convert`
-    tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
+  // Compute the scale for the tile
+  float tile_scale = max_abs / 448.f;
+
+  // The first lane of each half-warp writes the scale to kv_cache
+  if ((lane_idx == 0) || (lane_idx == 16)) {
     float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
     const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
-    kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
+    kv_cache_32bit[dst_idx] = tile_scale;
   }
 
-  __syncthreads();
+  // Now all threads in the block scale and write their elements
+  // NoPE data is packed in the first kv_lora_rank/2 bytes (first 256 bytes)
+  const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8);
+
+  uint8_t result[8];
+#pragma unroll
+  for (int i = 0; i < 8; i++) {
+    result[i] =
+        fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
+            vals[i], tile_scale);
+  }
 
-  // Now all threads in the block scale and write their element
-  const float scale_val = tile_scales[tile_idx];
-  const int64_t dst_idx = dst_idx_start + threadIdx.x;
-  kv_cache[dst_idx] =
-      fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
-          src_val, scale_val);
+  // Store as aligned 64-bit writes
+  *reinterpret_cast<uint64_t*>(&kv_cache[dst_idx_base]) =
+      *reinterpret_cast<const uint64_t*>(result);
 }
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
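The shared-memory staging and both __syncthreads() barriers are gone from the kernel above: each 128-element NoPE tile is now owned by a single 16-lane half-warp, so the max-abs reduction stays entirely in registers, and the FLT_MIN clamp on the scale is gone as well, which is why the #include <cfloat> at the top of the file could be removed. VLLM_SHFL_XOR_SYNC_WIDTH is a vLLM portability macro; as a rough standalone sketch, assuming it behaves like __shfl_xor_sync with a full mask and the given shuffle width, the half-warp butterfly reduction amounts to:

// Sketch only: butterfly (XOR) reduction of max_abs within one 16-lane
// half-warp. After log2(16) = 4 steps every lane holds the half-warp maximum,
// so no shared memory or block-wide barrier is needed.
__device__ float halfwarp_max_abs(float max_abs) {
#pragma unroll
  for (int offset = 8; offset > 0; offset /= 2) {
    // width = 16 confines the shuffle to this thread's half-warp
    max_abs = fmaxf(max_abs, __shfl_xor_sync(0xFFFFFFFFu, max_abs, offset, 16));
  }
  return max_abs;  // identical value in all 16 lanes of the half-warp
}

Because every lane ends up holding the tile maximum, each thread can compute tile_scale = max_abs / 448.f locally (448 being the largest finite FP8 E4M3 value), and only lanes 0 and 16 need to write the scale into the cache entry.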
@@ -741,13 +736,12 @@ void concat_and_cache_mla(
 
   if (kv_cache_dtype == "fp8_ds_mla") {
     dim3 grid(num_tokens);
-    // For the NoPE part, each tile of 128 elements is handled by 4 warps
-    // (128 threads). There are 4 total tiles, so 16 warps (512 threads).
-    // The first thread of the first warp in each tile writes the scale
-    // value for the tile. The RoPE part (last 64 elements) is handled
-    // by another 2 warps (64 threads).
-    // So in total, we use 18 warps (576 threads) per block.
-    dim3 block(576);
+    // For the NoPE part, each tile of 128 elements is handled by half of one
+    // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
+    // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
+    // The RoPE part (last 64 elements) is handled by another 1 warp (32
+    // threads). So in total, we use 3 warps (96 threads) per block.
+    dim3 block(96);
     DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                                CALL_CONCAT_AND_CACHE_DS_MLA);
   } else {
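With one block per token (dim3 grid(num_tokens)) and 96 threads per block, the work per token splits as follows: threads 0-63 (warps 0 and 1) each quantize and store 8 of the 512 NoPE elements, with every 16-lane half-warp owning one 128-element tile, while threads 64-95 (warp 2) each copy 2 of the 64 RoPE elements. The helper below is purely illustrative and not part of this change; it mirrors the kernel's index arithmetic, assuming kv_lora_rank = 512 as in DeepSeek MLA:

// Illustration only: which source elements a given thread of the 96-thread
// block handles, following the same index math as the kernel above.
struct ThreadWork {
  bool rope;       // true for threads 64..95 (the RoPE warp)
  int first_elem;  // first source element handled by this thread
  int num_elems;   // 8 NoPE elements or 2 RoPE elements
  int tile_idx;    // 0..3 for the 128-element NoPE tiles, -1 for RoPE
};

ThreadWork map_thread(int tid) {  // tid = threadIdx.x in [0, 96)
  if (tid >= 64) {
    return {true, (tid - 64) * 2, 2, -1};  // 32 threads * 2 = 64 RoPE elements
  }
  const int warp_idx = tid >> 5;                        // 0 or 1
  const int lane_idx = tid & 31;                        // 0..31
  const int tile_idx = warp_idx * 2 + (lane_idx >> 4);  // equals tid / 16
  return {false, tid * 8, 8, tile_idx};  // 64 threads * 8 = 512 NoPE elements
}

Within each cache entry the layout is unchanged from the previous 576-thread version: the tile scales sit at 32-bit offset kv_lora_rank / 4 + tile_idx and the RoPE values start at 16-bit offset kv_lora_rank / 2 + 8, right after the packed FP8 NoPE bytes and the four float scales; only the thread-to-work mapping and the reduction changed.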