Skip to content

Commit ce3f99b

Browse files
committed
format
1 parent 7572996 commit ce3f99b

File tree

6 files changed

+188
-204
lines changed

6 files changed

+188
-204
lines changed

python/infinicore/nn/functional/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,23 @@
44
from .linear import linear
55
from .paged_attention_v2 import paged_attention_v2
66
from .random_sample import random_sample
7+
from .reshape_and_cache import reshape_and_cache
78
from .rms_norm import rms_norm
89
from .rope import RopeAlgo, rope
910
from .silu import silu
1011
from .swiglu import swiglu
11-
from .reshape_and_cache import reshape_and_cache
1212

1313
__all__ = [
1414
"causal_softmax",
1515
"embedding",
1616
"flash_attention",
1717
"linear",
18+
"paged_attention_v2",
1819
"random_sample",
20+
"reshape_and_cache",
1921
"rms_norm",
2022
"RopeAlgo",
2123
"rope",
2224
"silu",
2325
"swiglu",
24-
"paged_attention_v2",
25-
"reshape_and_cache",
2626
]

python/infinicore/nn/functional/reshape_and_cache.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from infinicore.lib import _infinicore
2-
from infinicore.tensor import Tensor, empty
3-
2+
from infinicore.tensor import Tensor
43

54

65
def reshape_and_cache(
@@ -9,9 +8,9 @@ def reshape_and_cache(
98
key_cache: Tensor,
109
value_cache: Tensor,
1110
slot_mapping: Tensor,
12-
kv_cache_dtype:str,
11+
kv_cache_dtype: str,
1312
k_scale: Tensor,
14-
v_scale: Tensor ,
13+
v_scale: Tensor,
1514
):
1615
_infinicore.reshape_and_cache(
1716
key._underlying,

src/infinicore/pybind11/ops.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
#include "ops/paged_attention_prefill.hpp"
1717
#include "ops/paged_attention_v2.hpp"
1818
#include "ops/paged_caching.hpp"
19-
#include "ops/reshape_and_cache.hpp"
2019
#include "ops/random_sample.hpp"
2120
#include "ops/rearrange.hpp"
21+
#include "ops/reshape_and_cache.hpp"
2222
#include "ops/rms_norm.hpp"
2323
#include "ops/rope.hpp"
2424
#include "ops/silu.hpp"

src/infiniop/ops/reshape_and_cache/metax/reshape_and_cache_metax_kernels.cuh

Lines changed: 119 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -13,177 +13,162 @@ namespace op::reshape_and_cache::metax {
1313

1414
using Fp8KVCacheDataType = op::paged_attention_v2::vllm::Fp8KVCacheDataType;
1515

16-
17-
1816
// Used by vectorization_utils to copy/convert one element
1917
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
2018
struct CopyWithScaleOp {
21-
float scale;
22-
23-
__device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
24-
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
25-
dst = static_cast<OutT>(src);
26-
} else {
27-
// dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
28-
assert(false);
19+
float scale;
20+
21+
__device__ __forceinline__ void operator()(OutT &dst, const InT src) const {
22+
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
23+
dst = static_cast<OutT>(src);
24+
} else {
25+
// dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
26+
assert(false);
27+
}
2928
}
30-
}
3129
};
3230

33-
34-
35-
3631
// Vectorization containers
3732
template <typename scalar_t, size_t vec_size>
3833
struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
39-
scalar_t val[vec_size];
34+
scalar_t val[vec_size];
4035
};
4136

42-
4337
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
4438
struct DefaultVecOp {
45-
ScaOp scalar_op;
39+
ScaOp scalar_op;
4640

47-
__device__ __forceinline__ void operator()(
48-
vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const {
41+
__device__ __forceinline__ void operator()(
42+
vec_n_t<OutT, VEC_SIZE> &dst, const vec_n_t<InT, VEC_SIZE> &src) const {
4943
#pragma unroll
50-
for (int i = 0; i < VEC_SIZE; ++i) {
51-
scalar_op(dst.val[i], src.val[i]);
44+
for (int i = 0; i < VEC_SIZE; ++i) {
45+
scalar_op(dst.val[i], src.val[i]);
46+
}
5247
}
53-
}
5448
};
5549

5650
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
5751
typename ScaOp>
5852
__device__ inline void vectorize_with_alignment(
59-
const InT* in, OutT* out, int len, int tid, int stride,
60-
VecOp&& vec_op, // vec_n_t<InT,16> -> vec_n_t<OutT,16>
61-
ScaOp&& scalar_op) { // InT -> OutT
62-
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
63-
"VEC_SIZE must be a positive power-of-two");
64-
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
65-
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
66-
67-
// fast path when the whole region is already aligned
68-
// Note: currently the output is guaranteed to be same as the input, so we
69-
// don't check it here, comments here just for future reference.
70-
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
71-
if (can_vec) {
72-
int num_vec = len / VEC_SIZE;
53+
const InT *in, OutT *out, int len, int tid, int stride,
54+
VecOp &&vec_op, // vec_n_t<InT,16> -> vec_n_t<OutT,16>
55+
ScaOp &&scalar_op) { // InT -> OutT
56+
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
57+
"VEC_SIZE must be a positive power-of-two");
58+
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
59+
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
60+
61+
// fast path when the whole region is already aligned
62+
// Note: currently the output is guaranteed to be same as the input, so we
63+
// don't check it here, comments here just for future reference.
64+
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
65+
if (can_vec) {
66+
int num_vec = len / VEC_SIZE;
67+
68+
using vin_t = vec_n_t<InT, VEC_SIZE>;
69+
using vout_t = vec_n_t<OutT, VEC_SIZE>;
70+
auto *v_in = reinterpret_cast<const vin_t *>(in);
71+
auto *v_out = reinterpret_cast<vout_t *>(out);
72+
73+
for (int i = tid; i < num_vec; i += stride) {
74+
vout_t tmp;
75+
vec_op(tmp, v_in[i]);
76+
v_out[i] = tmp;
77+
}
78+
return;
79+
}
80+
81+
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
82+
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
83+
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
84+
prefix_elems /= sizeof(InT);
85+
prefix_elems = min(prefix_elems, len); // 0 ≤ prefix < 16
86+
87+
// 1. prefill the when it is unsafe to vectorize
88+
for (int i = tid; i < prefix_elems; i += stride) {
89+
scalar_op(out[i], in[i]);
90+
}
91+
92+
in += prefix_elems;
93+
out += prefix_elems;
94+
len -= prefix_elems;
7395

96+
int num_vec = len / VEC_SIZE;
7497
using vin_t = vec_n_t<InT, VEC_SIZE>;
7598
using vout_t = vec_n_t<OutT, VEC_SIZE>;
76-
auto* v_in = reinterpret_cast<const vin_t*>(in);
77-
auto* v_out = reinterpret_cast<vout_t*>(out);
99+
auto *v_in = reinterpret_cast<const vin_t *>(in);
100+
auto *v_out = reinterpret_cast<vout_t *>(out);
78101

102+
// 2. vectorize the main part
79103
for (int i = tid; i < num_vec; i += stride) {
80-
vout_t tmp;
81-
vec_op(tmp, v_in[i]);
82-
v_out[i] = tmp;
104+
vout_t tmp;
105+
vec_op(tmp, v_in[i]);
106+
v_out[i] = tmp;
83107
}
84-
return;
85-
}
86-
87-
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
88-
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
89-
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
90-
prefix_elems /= sizeof(InT);
91-
prefix_elems = min(prefix_elems, len); // 0 ≤ prefix < 16
92-
93-
// 1. prefill the when it is unsafe to vectorize
94-
for (int i = tid; i < prefix_elems; i += stride) {
95-
scalar_op(out[i], in[i]);
96-
}
97-
98-
in += prefix_elems;
99-
out += prefix_elems;
100-
len -= prefix_elems;
101-
102-
int num_vec = len / VEC_SIZE;
103-
using vin_t = vec_n_t<InT, VEC_SIZE>;
104-
using vout_t = vec_n_t<OutT, VEC_SIZE>;
105-
auto* v_in = reinterpret_cast<const vin_t*>(in);
106-
auto* v_out = reinterpret_cast<vout_t*>(out);
107-
108-
// 2. vectorize the main part
109-
for (int i = tid; i < num_vec; i += stride) {
110-
vout_t tmp;
111-
vec_op(tmp, v_in[i]);
112-
v_out[i] = tmp;
113-
}
114-
115-
// 3. handle the tail
116-
int tail_start = num_vec * VEC_SIZE;
117-
for (int i = tid + tail_start; i < len; i += stride) {
118-
scalar_op(out[i], in[i]);
119-
}
120-
}
121-
122108

109+
// 3. handle the tail
110+
int tail_start = num_vec * VEC_SIZE;
111+
for (int i = tid + tail_start; i < len; i += stride) {
112+
scalar_op(out[i], in[i]);
113+
}
114+
}
123115

124116
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
125-
__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
126-
OutT* out, int len,
117+
__device__ __forceinline__ void vectorize_with_alignment(const InT *in,
118+
OutT *out, int len,
127119
int tid, int stride,
128-
ScaOp&& scalar_op) {
129-
using Vec = DefaultVecOp<VEC_SIZE, InT, OutT, std::decay_t<ScaOp>>;
130-
vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, Vec{scalar_op},
131-
std::forward<ScaOp>(scalar_op));
120+
ScaOp &&scalar_op) {
121+
using Vec = DefaultVecOp<VEC_SIZE, InT, OutT, std::decay_t<ScaOp>>;
122+
vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, Vec{scalar_op},
123+
std::forward<ScaOp>(scalar_op));
132124
}
133125

134-
135126
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
136127
__global__ void reshape_and_cache_kernel(
137-
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
138-
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
139-
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
140-
// block_size, x]
141-
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
142-
// block_size]
143-
const int64_t* __restrict__ slot_mapping, // [num_tokens]
128+
const scalar_t *__restrict__ key, // [num_tokens, num_heads, head_size]
129+
const scalar_t *__restrict__ value, // [num_tokens, num_heads, head_size]
130+
cache_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
131+
// block_size, x]
132+
cache_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size,
133+
// block_size]
134+
const int64_t *__restrict__ slot_mapping, // [num_tokens]
144135
const int key_stride, const int value_stride, const int num_heads,
145136
const int head_size, const int block_size, const int x,
146-
const float* k_scale, const float* v_scale) {
147-
const int64_t token_idx = blockIdx.x;
148-
const int64_t slot_idx = slot_mapping[token_idx];
149-
if (slot_idx < 0) {
150-
// Padding token that should be ignored.
151-
return;
152-
}
153-
154-
const int64_t block_idx = slot_idx / block_size;
155-
const int64_t block_offset = slot_idx % block_size;
156-
157-
const int n = num_heads * head_size;
158-
for (int i = threadIdx.x; i < n; i += blockDim.x) {
159-
const int64_t src_key_idx = token_idx * key_stride + i;
160-
const int64_t src_value_idx = token_idx * value_stride + i;
161-
162-
const int head_idx = i / head_size;
163-
const int head_offset = i % head_size;
164-
const int x_idx = head_offset / x;
165-
const int x_offset = head_offset % x;
166-
167-
const int64_t tgt_key_idx =
168-
block_idx * num_heads * (head_size / x) * block_size * x +
169-
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
170-
block_offset * x + x_offset;
171-
const int64_t tgt_value_idx =
172-
block_idx * num_heads * head_size * block_size +
173-
head_idx * head_size * block_size + head_offset * block_size +
174-
block_offset;
175-
scalar_t tgt_key = key[src_key_idx];
176-
scalar_t tgt_value = value[src_value_idx];
177-
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
178-
key_cache[tgt_key_idx] = tgt_key;
179-
value_cache[tgt_value_idx] = tgt_value;
180-
} else {
181-
// key_cache[tgt_key_idx] =
182-
// fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
183-
// value_cache[tgt_value_idx] =
184-
// fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
185-
assert(false);
137+
const float *k_scale, const float *v_scale) {
138+
const int64_t token_idx = blockIdx.x;
139+
const int64_t slot_idx = slot_mapping[token_idx];
140+
if (slot_idx < 0) {
141+
// Padding token that should be ignored.
142+
return;
143+
}
144+
145+
const int64_t block_idx = slot_idx / block_size;
146+
const int64_t block_offset = slot_idx % block_size;
147+
148+
const int n = num_heads * head_size;
149+
for (int i = threadIdx.x; i < n; i += blockDim.x) {
150+
const int64_t src_key_idx = token_idx * key_stride + i;
151+
const int64_t src_value_idx = token_idx * value_stride + i;
152+
153+
const int head_idx = i / head_size;
154+
const int head_offset = i % head_size;
155+
const int x_idx = head_offset / x;
156+
const int x_offset = head_offset % x;
157+
158+
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + block_offset * x + x_offset;
159+
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size + head_idx * head_size * block_size + head_offset * block_size + block_offset;
160+
scalar_t tgt_key = key[src_key_idx];
161+
scalar_t tgt_value = value[src_value_idx];
162+
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
163+
key_cache[tgt_key_idx] = tgt_key;
164+
value_cache[tgt_value_idx] = tgt_value;
165+
} else {
166+
// key_cache[tgt_key_idx] =
167+
// fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
168+
// value_cache[tgt_value_idx] =
169+
// fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
170+
assert(false);
171+
}
186172
}
187-
}
188173
}
189174
} // namespace op::reshape_and_cache::metax

0 commit comments

Comments
 (0)