#include <algorithm>
#include <cassert>
- #include <map>
- #include <vector>
+ #include <cfloat>  // FLT_MIN

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
@@ -209,6 +208,20 @@ void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,

namespace vllm {

+ // Used to copy/convert one element
+ template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+ struct CopyWithScaleOp {
+   float scale;
+
+   __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+     if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+       dst = static_cast<OutT>(src);
+     } else {
+       dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+     }
+   }
+ };
+
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
    const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
@@ -224,58 +237,50 @@ __global__ void reshape_and_cache_kernel(
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
-     // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
+   const int h_block_count = head_size / x;  // number of x-element chunks per head

-   const int n = num_heads * head_size;
-   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-     const int64_t src_key_idx = token_idx * key_stride + i;
-     const int64_t src_value_idx = token_idx * value_stride + i;
-
-     const int head_idx = i / head_size;
-     const int head_offset = i % head_size;
-     const int x_idx = head_offset / x;
-     const int x_offset = head_offset % x;
-
-     const int64_t tgt_key_idx =
-         block_idx * num_heads * (head_size / x) * block_size * x +
-         head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
-         block_offset * x + x_offset;
-     const int64_t tgt_value_idx =
-         block_idx * num_heads * head_size * block_size +
-         head_idx * head_size * block_size + head_offset * block_size +
-         block_offset;
-     scalar_t tgt_key = key[src_key_idx];
-     scalar_t tgt_value = value[src_value_idx];
-     if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-       key_cache[tgt_key_idx] = tgt_key;
-       value_cache[tgt_value_idx] = tgt_value;
-     } else {
-       key_cache[tgt_key_idx] =
-           fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-       value_cache[tgt_value_idx] =
-           fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
-     }
+   const int h_block_idx = threadIdx.x;
+   if (h_block_idx >= num_heads * h_block_count) {
+     return;
  }
-   }

-   // Used by vectorization_utils to copy/convert one element
-   template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
-   struct CopyWithScaleOp {
-     float scale;
+   const int head_idx = h_block_idx / h_block_count;
+   const int h_block = h_block_idx % h_block_count;

-     __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
-       if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-         dst = static_cast<OutT>(src);
-       } else {
-         dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
-       }
+   const scalar_t* __restrict__ key_src =
+       key + token_idx * key_stride + head_idx * head_size + h_block * x;
+   const int64_t src_value_start =
+       token_idx * value_stride + head_idx * head_size + h_block * x;
+
+   cache_t* __restrict__ key_dst =
+       key_cache + block_idx * num_heads * h_block_count * block_size * x +
+       head_idx * h_block_count * block_size * x + h_block * block_size * x +
+       block_offset * x;
+   const int64_t tgt_value_start =
+       block_idx * num_heads * h_block_count * x * block_size +
+       head_idx * h_block_count * x * block_size + h_block * x * block_size +
+       block_offset;
+
+   constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
+   float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+   CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+   float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+   CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+
+   vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, x, 0, 1, k_op);
+
+   const scalar_t* __restrict__ value_src = value + src_value_start;
+   cache_t* __restrict__ value_dst = value_cache + tgt_value_start;
+ #pragma unroll
+   for (int i = 0; i < x; i++) {
+     v_op(value_dst[i * block_size], value_src[i]);
  }
-   };
+ }

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
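
For context on the helper used above: vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, x, 0, 1, k_op) is assumed here to apply the element-wise functor to the x contiguous key elements, using vectorized loads/stores when alignment allows. A minimal scalar-equivalent sketch follows; the helper name copy_chunk_scalar is made up for this sketch, and the real implementation in vectorization_utils may differ.

// Scalar-equivalent sketch of the assumed behavior: one thread (offset 0,
// stride 1) applies `op` to each of the n elements of its chunk. The real
// vectorize_with_alignment additionally issues vector loads/stores when the
// pointers are suitably aligned.
template <typename OutT, typename InT, typename Op>
__device__ __forceinline__ void copy_chunk_scalar(const InT* src, OutT* dst,
                                                  int n, const Op& op) {
  for (int i = 0; i < n; ++i) {
    op(dst[i], src[i]);  // CopyWithScaleOp: plain cast or fp8 scaled convert
  }
}
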
@@ -601,9 +606,10 @@ void reshape_and_cache(
  int key_stride = key.stride(0);
  int value_stride = value.stride(0);
+   int head_div_x = head_size / x;

  dim3 grid(num_tokens);
-   dim3 block(std::min(num_heads * head_size, 512));
+   dim3 block(std::min(num_heads * head_div_x, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
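
Overall, the rewritten kernel assigns one thread to one (head, x-chunk) pair rather than striding over all num_heads * head_size elements, which is why the launch above now sizes the block from num_heads * head_div_x. A small host-side sketch of the target addressing, matching the key cache layout [num_blocks, num_heads, head_size/x, block_size, x] and the value cache layout [num_blocks, num_heads, head_size, block_size] implied by the index math above; the function names are illustrative, not part of the PR.

#include <cstdint>

// Offset of the x-sized key chunk for (block_idx, head_idx, h_block,
// block_offset) in a key cache laid out as
// [num_blocks, num_heads, head_size/x, block_size, x].
inline int64_t key_chunk_offset(int64_t block_idx, int head_idx, int h_block,
                                int64_t block_offset, int num_heads,
                                int h_block_count, int block_size, int x) {
  return block_idx * num_heads * h_block_count * block_size * x +
         head_idx * h_block_count * block_size * x + h_block * block_size * x +
         block_offset * x;
}

// Offset of the first of the x value elements for the same indices in a value
// cache laid out as [num_blocks, num_heads, head_size, block_size]; successive
// elements of the chunk are block_size apart, matching the unrolled loop above.
inline int64_t value_chunk_offset(int64_t block_idx, int head_idx, int h_block,
                                  int64_t block_offset, int num_heads,
                                  int h_block_count, int block_size, int x) {
  return block_idx * num_heads * h_block_count * x * block_size +
         head_idx * h_block_count * x * block_size + h_block * x * block_size +
         block_offset;
}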