#include "utils/vec_copy.h"
#include "common/micros.h"

- using colossalAI::cuda::utils::copy_vector;
using colossalAI::cuda::utils::get_vec_size;
+ using colossalAI::cuda::utils::copy;
+ using colossalAI::funcs::CastFunctor;


- template <typename scalar_t, bool Aligned, int VecSize>
+ template <typename T, typename CacheT, bool Aligned, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
-     const scalar_t* __restrict__ key,
-     const scalar_t* __restrict__ value,
-     scalar_t* __restrict__ key_cache,
-     scalar_t* __restrict__ value_cache,
+     const T* __restrict__ key,
+     const T* __restrict__ value,
+     CacheT* __restrict__ key_cache,
+     CacheT* __restrict__ value_cache,
    const int* __restrict__ sequence_lengths,
    const int* __restrict__ cu_seqlens,
    const int* __restrict__ block_tables,
@@ -54,8 +55,8 @@ __global__ void context_kv_cache_memcpy_kernel(
                       + head_id * block_size * head_dim
                       + block_offset * head_dim + head_offset;

-         copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
-         copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
+         copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id);
+         copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id);
    }

    // tail process
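The old `copy_vector` copied `VecSize` elements of a single type, destination first; the new `copy` takes the source first and may change element type on the way into the cache, which is what enables an fp8 (`uint8_t`-backed) KV cache. The real helper is defined in `utils/vec_copy.h`; the following is only a minimal sketch of the pattern, with a hypothetical generic `CastFunctor` standing in for the one in ColossalAI's funcs headers.

// Sketch only -- the actual copy/CastFunctor live in utils/vec_copy.h and
// the funcs headers; this illustrates the (src, dst) contract used above.
template <typename From, typename To>
struct CastFunctor {
  __device__ __forceinline__ To operator()(From x) const {
    return static_cast<To>(x);  // real specializations quantize, e.g. to fp8
  }
};

// Convert-and-copy VecSize contiguous elements from src into the cache.
template <typename T, typename CacheT, int VecSize>
__device__ __forceinline__ void copy(const T* src, CacheT* dst) {
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    dst[i] = CastFunctor<T, CacheT>()(src[i]);
  }
}

Note the argument order is the reverse of `copy_vector(dst, src)`, which is why both call sites above swap their operands.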
@@ -69,22 +70,22 @@ __global__ void context_kv_cache_memcpy_kernel(
                       + head_id * block_size * head_dim
                       + block_offset * head_dim + head_offset;

-             key_cache[target_id] = key[key_src_id];
-             value_cache[target_id] = value[value_src_id];
+             key_cache[target_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
+             value_cache[target_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
        }
    }

}

- template <typename scalar_t>
+ template <typename T, typename CacheT>
void apply_context_kv_cache_memcpy(
-     at::Tensor& key,                  // [num_tokens, head_num, head_dim]
-     at::Tensor& value,                // [num_tokens, head_num, head_dim]
-     at::Tensor& key_cache,            // [num_blocks, head_num, block_size, head_dim]
-     at::Tensor& value_cache,          // [num_blocks, head_num, block_size, head_dim]
-     at::Tensor& sequence_lengths,     // [batch_size]
-     at::Tensor& cu_seqlens,           // [batch_size + 1]
-     at::Tensor& block_tables,         // [batch_size, max_seq_len]
+     torch::Tensor& key,               // [num_tokens, head_num, head_dim]
+     torch::Tensor& value,             // [num_tokens, head_num, head_dim]
+     torch::Tensor& key_cache,         // [num_blocks, head_num, block_size, head_dim]
+     torch::Tensor& value_cache,       // [num_blocks, head_num, block_size, head_dim]
+     torch::Tensor& sequence_lengths,  // [batch_size]
+     torch::Tensor& cu_seqlens,        // [batch_size + 1]
+     torch::Tensor& block_tables,      // [batch_size, max_seq_len]
    int max_seq_len_in_batch)
{
    int num_tokens = key.size(0);
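In the tail path each element is stored individually, so the conversion is written out explicitly through `CastFunctor<T, CacheT>`. For the fp8 cache (`CacheT = uint8_t`) that functor must quantize on the way in. The actual specializations, including the choice of fp8 format and any scaling, live in ColossalAI's funcs headers, not in this diff; continuing the sketch above, a hypothetical float-to-fp8 specialization using CUDA's `cuda_fp8.h` intrinsics, assuming unscaled e5m2 storage, might read:

#include <cuda_fp8.h>   // CUDA 11.8+; provides __nv_cvt_float_to_fp8
#include <cstdint>

// Hypothetical specialization -- the format (e4m3 vs e5m2) and the absence
// of a scale factor are assumptions here, not something this diff pins down.
template <>
struct CastFunctor<float, uint8_t> {
  __device__ __forceinline__ uint8_t operator()(float x) const {
    // Returns the raw 8-bit fp8 storage value, saturating to the finite range.
    return static_cast<uint8_t>(
        __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E5M2));
  }
};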
@@ -97,7 +98,7 @@ void apply_context_kv_cache_memcpy(
    int64_t value_stride = value.stride(0);
    int block_table_stride = block_tables.stride(0);

-     int vec_size = get_vec_size<scalar_t>(key);
+     int vec_size = get_vec_size<T>(key);

    bool aligned = true;
    if (head_dim % vec_size != 0) {
@@ -112,11 +113,11 @@ void apply_context_kv_cache_memcpy(

#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size)        \
    do {                                                                     \
-         context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
-             key.data_ptr<scalar_t>(),                                      \
-             value.data_ptr<scalar_t>(),                                    \
-             key_cache.data_ptr<scalar_t>(),                                \
-             value_cache.data_ptr<scalar_t>(),                              \
+         context_kv_cache_memcpy_kernel<T, CacheT, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+             reinterpret_cast<T*>(key.data_ptr()),                          \
+             reinterpret_cast<T*>(value.data_ptr()),                        \
+             reinterpret_cast<CacheT*>(key_cache.data_ptr()),               \
+             reinterpret_cast<CacheT*>(value_cache.data_ptr()),             \
        sequence_lengths.data_ptr<int>(),                                    \
        cu_seqlens.data_ptr<int>(),                                          \
        block_tables.data_ptr<int>(),                                        \
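The switch away from the typed accessor is presumably forced by the two-type kernel: ATen's `data_ptr<T>()` is keyed to c10 scalar types such as `at::Half`, which do not line up with the CUDA-side `half`/`__nv_bfloat16` the kernel is instantiated on, nor with a `uint8_t` view over an fp8 cache. So the launch macro fetches each pointer untyped and reinterprets it. In miniature:

// ATen's typed accessor checks the dtype against a c10 type (at::Half),
// not CUDA's half, so with T = half the arguments go through void*:
half* k     = reinterpret_cast<half*>(key.data_ptr());
uint8_t* kc = reinterpret_cast<uint8_t*>(key_cache.data_ptr());  // fp8 cache bytes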
@@ -161,26 +162,63 @@ void apply_context_kv_cache_memcpy(
}

void context_kv_cache_memcpy(
-     at::Tensor& key,                  // [num_tokens, head_num, head_dim]
-     at::Tensor& value,                // [num_tokens, head_num, head_dim]
-     at::Tensor& key_cache,            // [num_blocks, head_num, block_size, head_dim]
-     at::Tensor& value_cache,          // [num_blocks, head_num, block_size, head_dim]
-     at::Tensor& sequence_lengths,     // [batch_size]
-     at::Tensor& cu_seqlens,           // [batch_size + 1]
-     at::Tensor& block_tables,         // [batch_size, max_seq_len]
+     torch::Tensor& key,               // [num_tokens, head_num, head_dim]
+     torch::Tensor& value,             // [num_tokens, head_num, head_dim]
+     torch::Tensor& key_cache,         // [num_blocks, head_num, block_size, head_dim]
+     torch::Tensor& value_cache,       // [num_blocks, head_num, block_size, head_dim]
+     torch::Tensor& sequence_lengths,  // [batch_size]
+     torch::Tensor& cu_seqlens,        // [batch_size + 1]
+     torch::Tensor& block_tables,      // [batch_size, max_seq_len]
    int max_seq_len_in_batch)
{
-     DISPATCH_FLOAT_HALF_AND_BFLOAT(
-         key.scalar_type(),
-         "context_kv_cache_memcpy",
-         apply_context_kv_cache_memcpy<scalar_t>(
-             key,
-             value,
-             key_cache,
-             value_cache,
-             sequence_lengths,
-             cu_seqlens,
-             block_tables,
-             max_seq_len_in_batch
-         );)
+
+     TORCH_CHECK(key.scalar_type() == at::ScalarType::Float || key.scalar_type() == at::ScalarType::Half || key.scalar_type() == at::ScalarType::BFloat16,
+                 "Dtype of key should be float, half or bfloat16!");
+     TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key.scalar_type(),
+                 "Dtype of key and kvcache should be the same unless dtype of kvcache is fp8!");
+
+
+ #define _(T, CacheT)                           \
+     apply_context_kv_cache_memcpy<T, CacheT>(  \
+         key,                                   \
+         value,                                 \
+         key_cache,                             \
+         value_cache,                           \
+         sequence_lengths,                      \
+         cu_seqlens,                            \
+         block_tables,                          \
+         max_seq_len_in_batch                   \
+     )
+
+     if (key_cache.scalar_type() == at::ScalarType::Byte)
+     {
+         switch (key.scalar_type())
+         {
+             case at::ScalarType::Float:
+                 _(float, uint8_t);
+                 break;
+             case at::ScalarType::Half:
+                 _(half, uint8_t);
+                 break;
+             case at::ScalarType::BFloat16:
+                 _(__nv_bfloat16, uint8_t);
+                 break;
+         }
+     }
+     else
+     {
+         switch (key.scalar_type())
+         {
+             case at::ScalarType::Float:
+                 _(float, float);
+                 break;
+             case at::ScalarType::Half:
+                 _(half, half);
+                 break;
+             case at::ScalarType::BFloat16:
+                 _(__nv_bfloat16, __nv_bfloat16);
+                 break;
+         }
+     }
+ #undef _
}
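With this dispatcher, callers opt into the fp8 path simply by allocating the cache tensors with dtype Byte; one `uint8_t` slot stores one quantized element, so the cache shapes in the signature's comments are unchanged. A hypothetical host-side driver, with placeholder sizes and index tensors built to match those comments:

#include <torch/torch.h>

void demo_fp8_kv_cache() {
  // Placeholder sizes for illustration only.
  int64_t batch_size = 2, seq_len = 8, head_num = 8, head_dim = 64;
  int64_t num_tokens = batch_size * seq_len;
  int64_t num_blocks = 4, block_size = 16;

  auto kv_opts    = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto cache_opts = torch::dtype(torch::kByte).device(torch::kCUDA);  // kByte selects the fp8 branch
  auto idx_opts   = torch::dtype(torch::kInt32).device(torch::kCUDA);

  torch::Tensor key   = torch::randn({num_tokens, head_num, head_dim}, kv_opts);
  torch::Tensor value = torch::randn({num_tokens, head_num, head_dim}, kv_opts);
  torch::Tensor key_cache   = torch::empty({num_blocks, head_num, block_size, head_dim}, cache_opts);
  torch::Tensor value_cache = torch::empty({num_blocks, head_num, block_size, head_dim}, cache_opts);

  // Per-sequence lengths, prefix sums, and one cache block per sequence.
  torch::Tensor sequence_lengths = torch::full({batch_size}, seq_len, idx_opts);
  torch::Tensor cu_seqlens       = torch::arange(0, num_tokens + 1, seq_len, idx_opts);  // [0, 8, 16]
  torch::Tensor block_tables     = torch::arange(0, batch_size, idx_opts).reshape({batch_size, 1});

  context_kv_cache_memcpy(key, value, key_cache, value_cache,
                          sequence_lengths, cu_seqlens, block_tables,
                          /*max_seq_len_in_batch=*/seq_len);
}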