 #include "cuda_utils.h"
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
+#include "quantization/vectorization_utils.cuh"
 
 #ifdef USE_ROCM
   #include "quantization/fp8/amd/quant_utils.cuh"
@@ -261,14 +262,26 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+// Used by vectorization_utils to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
-                                         // head_size]
-    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
-                                         // head_size]
+    cache_t* __restrict__ key_cache,    // NHD or HND; see layout comments below
+    cache_t* __restrict__ value_cache,  // same as above
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int64_t block_stride, const int64_t page_stride,
     const int64_t head_stride, const int64_t key_stride,
@@ -282,25 +295,58 @@ __global__ void reshape_and_cache_flash_kernel(
   }
   const int64_t block_idx = slot_idx / block_size;
   const int64_t block_offset = slot_idx % block_size;
-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int64_t tgt_key_value_idx = block_idx * block_stride +
-                                      block_offset * page_stride +
-                                      head_idx * head_stride + head_offset;
-    scalar_t tgt_key = key[src_key_idx];
-    scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_value_idx] = tgt_key;
-      value_cache[tgt_key_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
+  const int n_elems = num_heads * head_size;
+
+  // pointers to the beginning of the source row for this token.
+  const scalar_t* __restrict__ key_src = key + token_idx * key_stride;
+  const scalar_t* __restrict__ value_src = value + token_idx * value_stride;
+
+  // find the start position inside the kv-cache for this token.
+  cache_t* __restrict__ key_dst =
+      key_cache + block_idx * block_stride + block_offset * page_stride;
+  cache_t* __restrict__ value_dst =
+      value_cache + block_idx * block_stride + block_offset * page_stride;
+
+  // this is true for the NHD layout where `head_stride == head_size`
+  const bool is_contiguous_heads = (head_stride == head_size);
+
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+  if (is_contiguous_heads) {
+    // NHD layout
+    // kv cache: [num_blocks, block_size, num_heads, head_size]
+    vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
+                                       blockDim.x, k_op);
+
+    vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
+                                       threadIdx.x, blockDim.x, v_op);
+
+  } else {
+    // HND layout: heads are strided, but each head_size segment is contiguous
+    // kv cache: [num_blocks, num_heads, block_size, head_size]
+    const int lane = threadIdx.x & 31;            // 0..31 within warp
+    const int warp_id = threadIdx.x >> 5;         // warp index within block
+    const int warps_per_block = blockDim.x >> 5;
+
+    for (int head = warp_id; head < num_heads; head += warps_per_block) {
+      const scalar_t* __restrict__ k_src_h = key_src + head * head_size;
+      const scalar_t* __restrict__ v_src_h = value_src + head * head_size;
+
+      cache_t* __restrict__ k_dst_h =
+          key_dst + static_cast<int64_t>(head) * head_stride;
+      cache_t* __restrict__ v_dst_h =
+          value_dst + static_cast<int64_t>(head) * head_stride;
+
+      // within each head, let the 32 threads of the warp perform the vector
+      // copy
+      vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
+                                         k_op);
+
+      vectorize_with_alignment<VEC_SIZE>(v_src_h, v_dst_h, head_size, lane, 32,
+                                         v_op);
     }
   }
 }
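
For readers unfamiliar with the functor-driven copy used in this diff, the following standalone sketch shows the same pattern in scalar form: a per-element operator() decides how one value is converted, and a block-strided loop applies it across a row. This is a simplified, hypothetical stand-in for illustration only; it is not vLLM's vectorize_with_alignment, which additionally issues aligned vector loads/stores, and the names strided_copy and CastOp do not exist in the codebase.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for CopyWithScaleOp's kAuto branch: a plain cast,
// no scaling.
template <typename OutT, typename InT>
struct CastOp {
  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
    dst = static_cast<OutT>(src);
  }
};

// Scalar analogue of the block-strided copy: each thread handles every
// blockDim.x-th element of the row and lets the functor convert it.
template <typename OutT, typename InT, typename Op>
__global__ void strided_copy(const InT* src, OutT* dst, int n, Op op) {
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    op(dst[i], src[i]);
  }
}

int main() {
  const int n = 1024;
  float* src;
  float* dst;
  cudaMallocManaged(&src, n * sizeof(float));
  cudaMallocManaged(&dst, n * sizeof(float));
  for (int i = 0; i < n; ++i) src[i] = static_cast<float>(i);

  // One block of 256 threads copies the whole row cooperatively.
  strided_copy<<<1, 256>>>(src, dst, n, CastOp<float, float>{});
  cudaDeviceSynchronize();

  printf("dst[42] = %f\n", dst[42]);  // expected: 42.000000
  cudaFree(src);
  cudaFree(dst);
  return 0;
}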