Commit 808ee6e

[Inference/Feat] Feat quant kvcache step2 (#5674)
1 parent 8ccb671 commit 808ee6e

File tree

4 files changed: 207 additions & 70 deletions

extensions/csrc/funcs/cast_functor.h

Lines changed: 97 additions & 21 deletions
@@ -9,6 +9,7 @@
 #endif

 #include <assert.h>
+#include <stdint.h>

 #include <functional>

@@ -175,6 +176,16 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, uint16_t, DEVICE, STMTS_WRAPPER({
   return res.x;
 }))

+// half raw -> fp8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
+  __half_raw tmp;
+  tmp.x = val;
+  __nv_fp8_storage_t res =
+      __nv_cvt_halfraw_to_fp8(
+          tmp, __NV_SATFINITE, __NV_E5M2);
+  return static_cast<uint8_t>(res);
+}))
+
 // fp8x2 -> half2 raw
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint32_t, DEVICE, STMTS_WRAPPER({
   union {
@@ -222,6 +233,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, half, DEVICE, STMTS_WRAPPER({
   return half(res);
 }))

+// half -> fp8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, uint8_t, DEVICE, STMTS_WRAPPER({
+  __half_raw tmp(val);
+  __nv_fp8_storage_t res =
+      __nv_cvt_halfraw_to_fp8(
+          tmp, __NV_SATFINITE, __NV_E5M2);
+  return static_cast<uint8_t>(res);
+}))
+
 // fp8x2 -> half2
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
   __half2_raw res =
@@ -230,6 +250,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
   return half2(res);
 }))

+// half2 -> fp8x2
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, uint16_t, DEVICE, STMTS_WRAPPER({
+  __half2_raw tmp(val);
+  __nv_fp8x2_storage_t res =
+      __nv_cvt_halfraw2_to_fp8x2(
+          tmp, __NV_SATFINITE, __NV_E5M2);
+  return static_cast<uint16_t>(res);
+}))
+
 // fp8x4 -> half4
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint32_t, dtype::half4, DEVICE, STMTS_WRAPPER({
@@ -242,6 +271,20 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return res;
 }))

+// half4 -> fp8x4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    dtype::half4, uint32_t, DEVICE, STMTS_WRAPPER({
+      half2 x, y;
+      x = val.x;
+      y = val.y;
+      uint16_t lo, hi;
+      lo = CastFunctor<half2, uint16_t>()(x);
+      hi = CastFunctor<half2, uint16_t>()(y);
+      uint32_t res;
+      asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(lo), "h"(hi));
+      return res;
+    }))
+
 // fp8x8 -> half8
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint2, dtype::half8, DEVICE, STMTS_WRAPPER({
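
Note: the `mov.b32 %0, {%1, %2}` inline PTX in the half4 -> fp8x4 specialization packs two 16-bit registers into one 32-bit register, with the first operand landing in the low 16 bits. A minimal plain-CUDA sketch of the same packing, not from this commit (the helper name pack_b32 is illustrative):

#include <cstdint>

// Equivalent of asm volatile("mov.b32 %0, {%1, %2};" : "=r"(res) : "h"(lo), "h"(hi));
// PTX packs the first vector element into the least significant bits.
__device__ uint32_t pack_b32(uint16_t lo, uint16_t hi) {
    return static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
}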
@@ -314,6 +357,14 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return res;
 }))

+// float -> fp8
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
+  __nv_fp8_storage_t res =
+      __nv_cvt_float_to_fp8(
+          val, __NV_SATFINITE, __NV_E5M2);
+  return static_cast<uint8_t>(res);
+}))
+
 // fp8x2 -> float2
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint16_t, float2, DEVICE, STMTS_WRAPPER({
@@ -328,6 +379,28 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return make_float2(lof, hif);
 }))

+// float2 -> fp8x2
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    float2, uint16_t, DEVICE, STMTS_WRAPPER({
+      uint16_t tmp1 =
+          static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x));
+      uint16_t tmp2 =
+          static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y));
+      uint16_t res = (tmp1 << 8U) | tmp2;
+      return res;
+    }))
+
+// float4 -> fp8x4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
+  uint32_t a, b, c, d;
+  a = CastFunctor<float, uint8_t>()(val.x);
+  b = CastFunctor<float, uint8_t>()(val.y);
+  c = CastFunctor<float, uint8_t>()(val.z);
+  d = CastFunctor<float, uint8_t>()(val.w);
+  return (a << 24U) | (b << 16U) |
+         (c << 8U) | d;
+}))
+
 // fp8x4 -> float4_
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
@@ -338,6 +411,14 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return res;
 }))

+// fp8x4 -> float4
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    uint32_t, float4, DEVICE, STMTS_WRAPPER({
+      dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
+      float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
+      return res;
+    }))
+
 // fp8x8 -> float8_
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
     uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
@@ -352,16 +433,6 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
       return res;
 }))

-// half -> fp8
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
-  __half_raw tmp;
-  tmp.x = val;
-  __nv_fp8_storage_t res =
-      __nv_cvt_halfraw_to_fp8(
-          tmp, __NV_SATFINITE, __NV_E5M2);
-  return static_cast<uint8_t>(res);
-}))
-
 // bf16 -> fp8
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
                                      STMTS_WRAPPER({
@@ -376,19 +447,24 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
 #endif
 }))

-// float -> fp8
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
-  __nv_fp8_storage_t res =
-      __nv_cvt_float_to_fp8(
-          val, __NV_SATFINITE, __NV_E5M2);
-  return static_cast<uint8_t>(res);
-}))
+// bf162 -> fp8x2
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
+    __nv_bfloat162, uint16_t, DEVICE, STMTS_WRAPPER({
+      uint16_t a =
+          static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x));
+      uint16_t b =
+          static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y));
+      return (a << 8U) | b;
+    }))

-// fp8x4 -> float4
+// bf164 -> fp8x4
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
-    uint32_t, float4, DEVICE, STMTS_WRAPPER({
-      dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
-      float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
+    dtype::bfloat164, uint32_t, DEVICE, STMTS_WRAPPER({
+      uint32_t res;
+      uint16_t a, b;
+      a = CastFunctor<__nv_bfloat162, uint16_t>()(val.x);
+      b = CastFunctor<__nv_bfloat162, uint16_t>()(val.y);
+      asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(a), "h"(b));
       return res;
 }))

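Note: every new quantize specialization funnels through the cuda_fp8.h conversion intrinsics with E5M2 interpretation and saturate-to-finite rounding. A self-contained sketch of the float round trip they implement, not from this commit (the helper name fp8_roundtrip is illustrative):

#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include <cstdint>

// Quantize float -> fp8 E5M2 (saturating to finite), then dequantize via
// half, mirroring CastFunctor<float, uint8_t> and the existing fp8 -> half path.
__device__ float fp8_roundtrip(float val) {
    uint8_t q = static_cast<uint8_t>(
        __nv_cvt_float_to_fp8(val, __NV_SATFINITE, __NV_E5M2));
    __half_raw h = __nv_cvt_fp8_to_halfraw(q, __NV_E5M2);
    return __half2float(__half(h));
}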
extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu

Lines changed: 82 additions & 44 deletions
@@ -4,16 +4,17 @@
 #include "utils/vec_copy.h"
 #include "common/micros.h"

-using colossalAI::cuda::utils::copy_vector;
 using colossalAI::cuda::utils::get_vec_size;
+using colossalAI::cuda::utils::copy;
+using colossalAI::funcs::CastFunctor;


-template<typename scalar_t, bool Aligned, int VecSize>
+template<typename T, typename CacheT, bool Aligned, int VecSize>
 __global__ void context_kv_cache_memcpy_kernel(
-    const scalar_t* __restrict__ key,
-    const scalar_t* __restrict__ value,
-    scalar_t* __restrict__ key_cache,
-    scalar_t* __restrict__ value_cache,
+    const T* __restrict__ key,
+    const T* __restrict__ value,
+    CacheT* __restrict__ key_cache,
+    CacheT* __restrict__ value_cache,
     const int* __restrict__ sequence_lengths,
     const int* __restrict__ cu_seqlens,
     const int* __restrict__ block_tables,
@@ -54,8 +55,8 @@ __global__ void context_kv_cache_memcpy_kernel(
                             + head_id * block_size * head_dim
                             + block_offset * head_dim + head_offset;

-            copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
-            copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
+            copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id);
+            copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id);
         }

         // tail process
@@ -69,22 +70,22 @@ __global__ void context_kv_cache_memcpy_kernel(
                             + head_id * block_size * head_dim
                             + block_offset * head_dim + head_offset;

-            key_cache[target_id] = key[key_src_id];
-            value_cache[target_id] = value[value_src_id];
+            key_cache[target_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
+            value_cache[target_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
         }
     }

 }

-template<typename scalar_t>
+template<typename T, typename CacheT>
 void apply_context_kv_cache_memcpy(
-    at::Tensor& key,                 // [num_tokens, head_num, head_dim]
-    at::Tensor& value,               // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
-    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
-    at::Tensor& sequence_lengths,    // [batch_size]
-    at::Tensor& cu_seqlens,          // [batch_size + 1]
-    at::Tensor& block_tables,        // [batch_size, max_seq_len]
+    torch::Tensor& key,              // [num_tokens, head_num, head_dim]
+    torch::Tensor& value,            // [num_tokens, head_num, head_dim]
+    torch::Tensor& key_cache,        // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& value_cache,      // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& sequence_lengths, // [batch_size]
+    torch::Tensor& cu_seqlens,       // [batch_size + 1]
+    torch::Tensor& block_tables,     // [batch_size, max_seq_len]
     int max_seq_len_in_batch)
 {
     int num_tokens = key.size(0);
@@ -97,7 +98,7 @@ void apply_context_kv_cache_memcpy(
     int64_t value_stride = value.stride(0);
     int block_table_stride = block_tables.stride(0);

-    int vec_size = get_vec_size<scalar_t>(key);
+    int vec_size = get_vec_size<T>(key);

     bool aligned = true;
     if (head_dim % vec_size != 0) {
@@ -112,11 +113,11 @@

 #define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \
     do { \
-        context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
-            key.data_ptr<scalar_t>(), \
-            value.data_ptr<scalar_t>(), \
-            key_cache.data_ptr<scalar_t>(), \
-            value_cache.data_ptr<scalar_t>(), \
+        context_kv_cache_memcpy_kernel<T, CacheT, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+            reinterpret_cast<T*>(key.data_ptr()), \
+            reinterpret_cast<T*>(value.data_ptr()), \
+            reinterpret_cast<CacheT*>(key_cache.data_ptr()), \
+            reinterpret_cast<CacheT*>(value_cache.data_ptr()), \
             sequence_lengths.data_ptr<int>(), \
             cu_seqlens.data_ptr<int>(), \
             block_tables.data_ptr<int>(), \
@@ -161,26 +162,63 @@ void apply_context_kv_cache_memcpy(
 }

 void context_kv_cache_memcpy(
-    at::Tensor& key,                 // [num_tokens, head_num, head_dim]
-    at::Tensor& value,               // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
-    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
-    at::Tensor& sequence_lengths,    // [batch_size]
-    at::Tensor& cu_seqlens,          // [batch_size + 1]
-    at::Tensor& block_tables,        // [batch_size, max_seq_len]
+    torch::Tensor& key,              // [num_tokens, head_num, head_dim]
+    torch::Tensor& value,            // [num_tokens, head_num, head_dim]
+    torch::Tensor& key_cache,        // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& value_cache,      // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& sequence_lengths, // [batch_size]
+    torch::Tensor& cu_seqlens,       // [batch_size + 1]
+    torch::Tensor& block_tables,     // [batch_size, max_seq_len]
     int max_seq_len_in_batch)
 {
-    DISPATCH_FLOAT_HALF_AND_BFLOAT(
-        key.scalar_type(),
-        "context_kv_cache_memcpy",
-        apply_context_kv_cache_memcpy<scalar_t>(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            sequence_lengths,
-            cu_seqlens,
-            block_tables,
-            max_seq_len_in_batch
-        );)
+
+    TORCH_CHECK(key.scalar_type() == at::ScalarType::Float || key.scalar_type() == at::ScalarType::Half || key.scalar_type() == at::ScalarType::BFloat16,
+                "Dtype of key should be float, half or bfloat16!");
+    TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key.scalar_type(),
+                "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
+
+
+#define _(T, CacheT) \
+    apply_context_kv_cache_memcpy<T, CacheT>( \
+        key, \
+        value, \
+        key_cache, \
+        value_cache, \
+        sequence_lengths, \
+        cu_seqlens, \
+        block_tables, \
+        max_seq_len_in_batch \
+    )
+
+    if(key_cache.scalar_type() == at::ScalarType::Byte)
+    {
+        switch (key.scalar_type())
+        {
+            case at::ScalarType::Float:
+                _(float, uint8_t);
+                break;
+            case at::ScalarType::Half:
+                _(half, uint8_t);
+                break;
+            case at::ScalarType::BFloat16:
+                _(__nv_bfloat16, uint8_t);
+                break;
+        }
+    }
+    else
+    {
+        switch (key.scalar_type())
+        {
+            case at::ScalarType::Float:
+                _(float, float);
+                break;
+            case at::ScalarType::Half:
+                _(half, half);
+                break;
+            case at::ScalarType::BFloat16:
+                _(__nv_bfloat16, __nv_bfloat16);
+                break;
+        }
+    }
+#undef _
 }
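
Note: the dispatch above replaces the old single-type DISPATCH_FLOAT_HALF_AND_BFLOAT macro so that the cache element type can differ from the input type (uint8_t storage when the cache is fp8). A standalone sketch of the same two-level dispatch, not from this commit; the helper name run_memcpy stands in for the real launcher:

#include <torch/torch.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cstdint>
#include <iostream>

template <typename T, typename CacheT>
void run_memcpy(torch::Tensor& key) {
    // Stand-in for apply_context_kv_cache_memcpy<T, CacheT>(...).
    (void)key;
    std::cout << "dispatched: sizeof(T)=" << sizeof(T)
              << ", sizeof(CacheT)=" << sizeof(CacheT) << "\n";
}

void dispatch_memcpy(torch::Tensor& key, torch::Tensor& key_cache) {
    // Byte-typed cache means fp8 storage; otherwise cache dtype matches input.
    const bool fp8_cache = key_cache.scalar_type() == at::ScalarType::Byte;
    switch (key.scalar_type()) {
        case at::ScalarType::Float:
            if (fp8_cache) run_memcpy<float, uint8_t>(key);
            else           run_memcpy<float, float>(key);
            break;
        case at::ScalarType::Half:
            if (fp8_cache) run_memcpy<half, uint8_t>(key);
            else           run_memcpy<half, half>(key);
            break;
        case at::ScalarType::BFloat16:
            if (fp8_cache) run_memcpy<__nv_bfloat16, uint8_t>(key);
            else           run_memcpy<__nv_bfloat16, __nv_bfloat16>(key);
            break;
        default:
            TORCH_CHECK(false, "Dtype of key should be float, half or bfloat16!");
    }
}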

extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ void flash_decoding_attention(

     TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
                 "Dtype of query should be float, half or bfloat16!");
-    TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key_cache.scalar_type(),
+    TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
                 "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");

     if(key_cache.scalar_type() == at::ScalarType::Byte)
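
Note: the original condition compared key_cache.scalar_type() with itself, which is always true, so a query/cache dtype mismatch was never rejected; the fix compares against query.scalar_type(), enforcing matching dtypes unless the cache is stored as fp8 (Byte).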
