Commit ef8e4ff

[Inference/Feat] Add kvcache quant support for fused_rotary_embedding_cache_copy (#5680)
1 parent 5cd75ce commit ef8e4ff

File tree

7 files changed: +226 -125 lines changed

extensions/csrc/common/mp_type_traits.h

Lines changed: 17 additions & 0 deletions

@@ -4,6 +4,11 @@
 
 #include "micros.h"
 
+#if defined(COLOSSAL_WITH_CUDA)
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#endif
+
 namespace colossalAI {
 namespace common {
 
@@ -27,6 +32,18 @@ struct MPTypeTrait<at::BFloat16> {
   using Type = float;
 };
 
+#if defined(COLOSSAL_WITH_CUDA)
+template <>
+struct MPTypeTrait<half> {
+  using Type = float;
+};
+
+template <>
+struct MPTypeTrait<__nv_bfloat16> {
+  using Type = float;
+};
+#endif
+
 template <bool high_precision, typename T>
 struct ScalarTypeTrait {
   using Type =
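MPTypeTrait maps each storage dtype to the wider type used for intermediate arithmetic, and the new specializations extend that mapping from the at::Half/at::BFloat16 wrappers to the device-native half and __nv_bfloat16 types that the CUDA kernels handle directly. A minimal standalone sketch of how such a trait is typically consumed; only the trait shape comes from the diff, and the scaled_add helper is hypothetical:

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Default: compute in float.
template <typename T>
struct MPTypeTrait {
  using Type = float;
};

template <>
struct MPTypeTrait<half> {
  using Type = float;
};

template <>
struct MPTypeTrait<__nv_bfloat16> {
  using Type = float;
};

// Hypothetical device helper: widen to the math type, accumulate,
// then narrow back to the storage type on the way out.
template <typename T>
__device__ T scaled_add(T a, T b, float alpha) {
  using MT = typename MPTypeTrait<T>::Type;
  MT acc = static_cast<MT>(a) + alpha * static_cast<MT>(b);
  return static_cast<T>(acc);
}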

extensions/csrc/funcs/binary_functor.h

Lines changed: 19 additions & 0 deletions

@@ -56,6 +56,11 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
                                        typename T)
 
 #if defined(COLOSSAL_WITH_CUDA)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMinus,
+                                       DEVICE, STMTS_WRAPPER({
+                                         return __hsub(lhs, rhs);
+                                       }))
+
 COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
                                        DEVICE, STMTS_WRAPPER({
                                          return __hadd(lhs, rhs);
@@ -71,6 +76,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
                                        DEVICE, STMTS_WRAPPER({
                                          return __hadd(lhs, rhs);
                                        }))
+
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
+                                       __nv_bfloat16, BinaryOpType::kMinus,
+                                       DEVICE, STMTS_WRAPPER({
+                                         return __hsub(lhs, rhs);
+                                       }))
+
 COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
                                        __nv_bfloat162, BinaryOpType::kAdd,
                                        DEVICE, STMTS_WRAPPER({
@@ -82,6 +94,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
     STMTS_WRAPPER({
       return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
     }))
+
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
+    __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMinus, DEVICE,
+    STMTS_WRAPPER({
+      return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
+    }))
+
 COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
     __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
     STMTS_WRAPPER({
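The three new specializations give BinaryOpType::kMinus the same coverage kAdd already had: a half version and a bf16 version via the __hsub intrinsic, plus a bf16 version that round-trips through float, the usual fallback in this file for architectures without native bf16 arithmetic. A standalone sketch of what such a specialization plausibly expands to; the real COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION macro generates an equivalent functor, so this struct is illustrative only:

#include <cuda_bf16.h>

struct BF16MinusFunctor {
  __device__ __nv_bfloat16 operator()(__nv_bfloat16 lhs,
                                      __nv_bfloat16 rhs) const {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __hsub(lhs, rhs);  // native bf16 subtraction (Ampere and newer)
#else
    // Fallback: round-trip through float where bf16 arithmetic is unavailable.
    return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
#endif
  }
};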

extensions/csrc/funcs/cast_functor.h

Lines changed: 4 additions & 0 deletions

@@ -94,6 +94,10 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
                                      STMTS_WRAPPER({
                                        return __float2bfloat16_rn(val);
                                      }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, DEVICE,
+                                     STMTS_WRAPPER({
+                                       return __bfloat162float(val);
+                                     }))
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
                                      STMTS_WRAPPER({
                                        dtype::bfloat164 dst;
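The file already had the float -> __nv_bfloat16 direction; this hunk adds the reverse cast so cached bf16 values can be widened back to float, which the quantized-cache path needs when it mixes precisions. Written out as a plain functor rather than through the macro; the CastFunctor shape here is an assumption for illustration, only __bfloat162float comes from the diff:

#include <cuda_bf16.h>

template <typename From, typename To>
struct CastFunctor;

template <>
struct CastFunctor<__nv_bfloat16, float> {
  __device__ float operator()(__nv_bfloat16 val) const {
    return __bfloat162float(val);  // exact: bf16 widens losslessly to float
  }
};

// Usage inside a kernel: float f = CastFunctor<__nv_bfloat16, float>()(x);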

extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu

Lines changed: 0 additions & 6 deletions

@@ -192,12 +192,6 @@ void context_kv_cache_memcpy(
     int max_seq_len_in_batch)
 {
 
-    TORCH_CHECK(key.scalar_type() == at::ScalarType::Float || key.scalar_type() == at::ScalarType::Half || key.scalar_type() == at::ScalarType::BFloat16,
-                "Dtype of key should be float, half or bfloat16!");
-    TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key.scalar_type(),
-                "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
-
-
 #define _(T, CacheT)                                                          \
     apply_context_kv_cache_memcpy<T, CacheT>(                                 \
         key,                                                                  \
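The two TORCH_CHECK guards are deleted from this entry point rather than changed; the macro-driven dispatch below still selects the (T, CacheT) instantiation. For reference, a hedged sketch of the dtype contract the removed checks expressed, factored into a standalone helper; the function name is hypothetical, and only the dtype logic mirrors the deleted lines:

#include <torch/extension.h>

// Hypothetical helper; the rules below restate the removed TORCH_CHECKs.
void check_kv_cache_dtypes(const torch::Tensor& key,
                           const torch::Tensor& key_cache) {
  TORCH_CHECK(key.scalar_type() == at::ScalarType::Float ||
                  key.scalar_type() == at::ScalarType::Half ||
                  key.scalar_type() == at::ScalarType::BFloat16,
              "Dtype of key should be float, half or bfloat16!");
  // A Byte (uint8) cache signals an fp8-quantized kvcache; otherwise the
  // cache must store the same dtype as the incoming keys.
  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte ||
                  key_cache.scalar_type() == key.scalar_type(),
              "Dtype of key and kvcache should be the same unless kvcache is fp8!");
}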

extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu

Lines changed: 0 additions & 6 deletions

@@ -380,12 +380,6 @@ void flash_decoding_attention(
     const c10::optional<torch::Tensor>& alibi_slopes,
     float scale) {
 
-
-    TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
-                "Dtype of query should be float, half or bfloat16!");
-    TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
-                "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
-
     if(key_cache.scalar_type() == at::ScalarType::Byte)
     {
         switch (query.scalar_type()) {
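The same pair of guards disappears here, and the surviving context shows where the fp8 path hangs off: a Byte cache selects the quantized cache type, and a switch over query.scalar_type() picks the input instantiation. A hedged sketch of that dispatch shape; the launcher name and its arguments are hypothetical, only the branching mirrors the diff:

if (key_cache.scalar_type() == at::ScalarType::Byte) {
  // fp8 kvcache stored as uint8: pair each query dtype with CacheT = uint8_t.
  switch (query.scalar_type()) {
    case at::ScalarType::Float:
      // launch_flash_decoding<float, uint8_t>(/* ... */);       // hypothetical
      break;
    case at::ScalarType::Half:
      // launch_flash_decoding<at::Half, uint8_t>(/* ... */);    // hypothetical
      break;
    case at::ScalarType::BFloat16:
      // launch_flash_decoding<at::BFloat16, uint8_t>(/* ... */);
      break;
    default:
      TORCH_CHECK(false, "Unsupported query dtype for fp8 kvcache!");
  }
}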
