Skip to content

Commit 8ccb671

Browse files
authored
[Inference/Feat] Add kvcache quantization support for FlashDecoding (#5656)
1 parent 5be590b commit 8ccb671

File tree

5 files changed

+480
-172
lines changed

5 files changed

+480
-172
lines changed

extensions/csrc/common/vec_type_traits.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <cuda_fp16.h>
66
#endif
77

8+
#include <ATen/ATen.h>
89
#include <stdint.h>
910

1011
#include "common/data_type.h"
@@ -27,6 +28,7 @@ struct FloatVecTypeTrait {};
2728
VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
2829

2930
#if defined(COLOSSAL_WITH_CUDA)
31+
3032
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16)
3133
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162)
3234
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2)
@@ -35,18 +37,19 @@ VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 1, half)
3537
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 2, half2)
3638
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2)
3739
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
38-
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
39-
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
40-
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
41-
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
42-
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
43-
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
40+
41+
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, uint16_t)
42+
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, uint32_t)
43+
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, uint2)
4444
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162);
4545
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, dtype::bfloat164);
4646
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, dtype::bfloat168);
4747
VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2);
4848
VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4);
4949
VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8);
50+
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
51+
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
52+
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
5053
#endif /* defined(COLOSSAL_WITH_CUDA) */
5154

5255
#undef VEC_TYPE_TRAITS_SPECIALIZATION

0 commit comments

Comments
 (0)