@@ -13,177 +13,162 @@ namespace op::reshape_and_cache::metax {
1313
1414using Fp8KVCacheDataType = op::paged_attention_v2::vllm::Fp8KVCacheDataType;
1515
16-
17-
// Element-wise copy/convert functor used by the vectorization utilities.
// Copies one InT element into an OutT slot; the fp8 scaled-conversion path is
// not implemented yet on this backend (the scale member is reserved for it).
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
struct CopyWithScaleOp {
    float scale;  // conversion scale; only meaningful for the (disabled) fp8 path

    __device__ __forceinline__ void operator()(OutT &dst, const InT src) const {
        if constexpr (kv_dt != Fp8KVCacheDataType::kAuto) {
            // fp8 scaled conversion is not wired up on this backend yet.
            // dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
            assert(false);
        } else {
            dst = static_cast<OutT>(src);
        }
    }
};
3230
33-
34-
35-
// Vectorization container: a fixed-size array of vec_size scalars whose
// alignment matches its total byte size, so a single aligned load/store can
// move the whole vector (e.g. 16 floats -> one 64-byte transaction).
template <typename scalar_t, size_t vec_size>
struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
    scalar_t val[vec_size];
};
4136
42-
// Adapter that lifts a scalar functor (InT -> OutT) to a vector functor
// (vec_n_t<InT, VEC_SIZE> -> vec_n_t<OutT, VEC_SIZE>) by applying it lane by
// lane. Used when the caller supplies only an element-wise operation.
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
struct DefaultVecOp {
    ScaOp scalar_op;  // element-wise operation applied to each lane

    __device__ __forceinline__ void operator()(
        vec_n_t<OutT, VEC_SIZE> &dst, const vec_n_t<InT, VEC_SIZE> &src) const {
#pragma unroll
        for (int lane = 0; lane < VEC_SIZE; ++lane) {
            scalar_op(dst.val[lane], src.val[lane]);
        }
    }
};
5549
// Grid/block-strided copy of `len` elements from `in` to `out`, vectorized in
// groups of VEC_SIZE wherever alignment allows. Threads are identified by
// (tid, stride) and cooperatively cover the range; each thread touches a
// disjoint set of indices, so no synchronization is required.
//
// vec_op:    vec_n_t<InT, VEC_SIZE> -> vec_n_t<OutT, VEC_SIZE>
// scalar_op: InT -> OutT
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
          typename ScaOp>
__device__ inline void vectorize_with_alignment(
    const InT *in, OutT *out, int len, int tid, int stride,
    VecOp &&vec_op, ScaOp &&scalar_op) {
    static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
                  "VEC_SIZE must be a positive power-of-two");
    // Bytes covered by one vector transaction (e.g. 16 * 4 B = 64 B).
    constexpr int WIDTH = VEC_SIZE * sizeof(InT);
    const uintptr_t base_addr = reinterpret_cast<uintptr_t>(in);

    using vin_t = vec_n_t<InT, VEC_SIZE>;
    using vout_t = vec_n_t<OutT, VEC_SIZE>;

    // Fast path: input pointer already WIDTH-aligned and len is an exact
    // multiple of VEC_SIZE, so the entire range can go through the vector op.
    // NOTE(review): only the input alignment is checked; the output is assumed
    // to share it (currently guaranteed by the callers) — revisit if that
    // guarantee ever changes.
    const bool fully_vectorizable =
        ((base_addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
    if (fully_vectorizable) {
        const int num_vec = len / VEC_SIZE;
        auto *v_in = reinterpret_cast<const vin_t *>(in);
        auto *v_out = reinterpret_cast<vout_t *>(out);

        for (int i = tid; i < num_vec; i += stride) {
            vout_t chunk;
            vec_op(chunk, v_in[i]);
            v_out[i] = chunk;
        }
        return;
    }

    // How many leading scalar elements are needed to reach the next WIDTH
    // boundary (0 when the pointer is already aligned).
    const int misalignment_offset = base_addr & (WIDTH - 1);  // addr % WIDTH
    const int alignment_bytes = WIDTH - misalignment_offset;  // bytes to boundary
    int prefix_elems = (alignment_bytes & (WIDTH - 1)) / static_cast<int>(sizeof(InT));
    prefix_elems = min(prefix_elems, len);  // clamp for very short ranges

    // 1. scalar prefix: elements before the first aligned boundary, where it
    //    is unsafe to vectorize.
    for (int i = tid; i < prefix_elems; i += stride) {
        scalar_op(out[i], in[i]);
    }

    in += prefix_elems;
    out += prefix_elems;
    len -= prefix_elems;

    const int num_vec = len / VEC_SIZE;
    auto *v_in = reinterpret_cast<const vin_t *>(in);
    auto *v_out = reinterpret_cast<vout_t *>(out);

    // 2. vectorized main body over the aligned middle section.
    for (int i = tid; i < num_vec; i += stride) {
        vout_t chunk;
        vec_op(chunk, v_in[i]);
        v_out[i] = chunk;
    }

    // 3. scalar tail: leftover elements after the last full vector.
    const int tail_start = num_vec * VEC_SIZE;
    for (int i = tid + tail_start; i < len; i += stride) {
        scalar_op(out[i], in[i]);
    }
}
123115
// Convenience overload: the caller supplies only an element-wise functor.
// Wrap it in DefaultVecOp so the aligned middle section still runs VEC_SIZE
// elements per iteration, and forward the scalar functor for the unaligned
// prefix/tail.
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
__device__ __forceinline__ void vectorize_with_alignment(const InT *in,
                                                         OutT *out, int len,
                                                         int tid, int stride,
                                                         ScaOp &&scalar_op) {
    using Vec = DefaultVecOp<VEC_SIZE, InT, OutT, std::decay_t<ScaOp>>;
    // Copy the scalar op into the vector adapter first, then forward the
    // original to the full overload.
    Vec vec_wrapper{scalar_op};
    vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, vec_wrapper,
                                       std::forward<ScaOp>(scalar_op));
}
133125
134-
// Scatter one token's K/V projections into the paged KV cache.
//
// Launch layout: one thread block per token (token_idx = blockIdx.x); threads
// stride over the num_heads * head_size elements of that token.
//
// key_cache layout:   [num_blocks, num_heads, head_size/x, block_size, x]
// value_cache layout: [num_blocks, num_heads, head_size, block_size]
// slot_mapping[token] < 0 marks a padding token that must be skipped.
// k_scale/v_scale are only used by the (currently disabled) fp8 path.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
    const scalar_t *__restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t *__restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t *__restrict__ key_cache,     // [num_blocks, num_heads, head_size/x,
                                         //  block_size, x]
    cache_t *__restrict__ value_cache,   // [num_blocks, num_heads, head_size,
                                         //  block_size]
    const int64_t *__restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x,
    const float *k_scale, const float *v_scale) {
    const int64_t token_idx = blockIdx.x;
    const int64_t slot_idx = slot_mapping[token_idx];
    if (slot_idx < 0) {
        // Padding token that should be ignored.
        return;
    }

    const int64_t block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;

    // Hoist loop-invariant integer work out of the per-element loop: integer
    // division is expensive on GPU and `head_size / x` plus the per-block base
    // offsets do not depend on the loop index. The arithmetic (including the
    // int -> int64 promotion via block_idx) matches the unhoisted form exactly.
    const int x_chunks = head_size / x;
    const int64_t key_block_base =
        block_idx * num_heads * x_chunks * block_size * x;
    const int64_t value_block_base =
        block_idx * num_heads * head_size * block_size;

    const int n = num_heads * head_size;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        const int64_t src_key_idx = token_idx * key_stride + i;
        const int64_t src_value_idx = token_idx * value_stride + i;

        // Decompose flat element index i into (head, offset-within-head) and
        // the key cache's (x_idx, x_offset) sub-tiling.
        const int head_idx = i / head_size;
        const int head_offset = i % head_size;
        const int x_idx = head_offset / x;
        const int x_offset = head_offset % x;

        const int64_t tgt_key_idx =
            key_block_base + head_idx * x_chunks * block_size * x +
            x_idx * block_size * x + block_offset * x + x_offset;
        const int64_t tgt_value_idx =
            value_block_base + head_idx * head_size * block_size +
            head_offset * block_size + block_offset;
        scalar_t tgt_key = key[src_key_idx];
        scalar_t tgt_value = value[src_value_idx];
        if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
            key_cache[tgt_key_idx] = tgt_key;
            value_cache[tgt_value_idx] = tgt_value;
        } else {
            // fp8 scaled conversion is not wired up on this backend yet.
            // key_cache[tgt_key_idx] =
            //     fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
            // value_cache[tgt_value_idx] =
            //     fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
            assert(false);
        }
    }
}
189174} // namespace op::reshape_and_cache::metax
0 commit comments