
Commit 7db2675

issue/923 - ninetoothed kv_caching
1 parent 1f4550c commit 7db2675

13 files changed (+456 -49 lines)

include/infinicore/ops/kv_caching.hpp

Lines changed: 1 addition & 6 deletions
@@ -1,4 +1,4 @@
-#pragma
+#pragma once
 
 #include "../device.hpp"
 #include "common/op.hpp"
@@ -15,11 +15,6 @@ class KVCaching {
     static common::OpDispatcher<schema> &dispatcher();
 };
 
-Tensor kv_caching(Tensor k_cache,
-                  Tensor v_cache,
-                  Tensor k,
-                  Tensor v,
-                  Tensor past_kv_lengths);
 void kv_caching_(Tensor k_cache,
                  Tensor v_cache,
                  Tensor k,

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,7 @@
 from infinicore.ops.add import add
 from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
 from infinicore.ops.attention import attention
+from infinicore.ops.kv_caching import kv_caching
 from infinicore.ops.matmul import matmul
 from infinicore.ops.mul import mul
 from infinicore.ops.narrow import narrow
@@ -115,6 +116,7 @@
     "add_rms_norm",
     "add_rms_norm_",
     "attention",
+    "kv_caching",
     "matmul",
     "mul",
     "narrow",
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from infinicore.lib import _infinicore
+
+
+def kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
+    _infinicore.kv_caching_(
+        k_cache._underlying,
+        v_cache._underlying,
+        k._underlying,
+        v._underlying,
+        past_kv_lengths._underlying,
+    )
+
+    return k_cache, v_cache
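
The wrapper above is the new public entry point exported from infinicore. A minimal usage sketch (illustrative only; the five arguments are assumed to be already-constructed infinicore tensors with the meanings described in the pybind11 docstring further down):

import infinicore

# k_cache, v_cache: preallocated cache tensors, updated in place
# k, v:             new key/value tensors for the current step
# past_kv_lengths:  int64 tensor holding each sequence's current length
k_cache, v_cache = infinicore.kv_caching(k_cache, v_cache, k, v, past_kv_lengths)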

src/infinicore/ops/kv_caching/kv_caching.cc

Lines changed: 0 additions & 9 deletions
@@ -28,15 +28,6 @@ void KVCaching::execute(Tensor k_cache,
     func(k_cache, v_cache, k, v, past_kv_lengths);
 }
 
-Tensor kv_caching(Tensor k_cache,
-                  Tensor v_cache,
-                  Tensor k,
-                  Tensor v,
-                  Tensor past_kv_lengths) {
-    KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths);
-    return k_cache; // or v_cache, depending on the intended use
-}
-
 void kv_caching_(Tensor k_cache,
                  Tensor v_cache,
                  Tensor k,

src/infinicore/pybind11/ops.hpp

Lines changed: 5 additions & 3 deletions
@@ -8,6 +8,7 @@
 #include "ops/causal_softmax.hpp"
 #include "ops/embedding.hpp"
 #include "ops/flash_attention.hpp"
+#include "ops/kv_caching.hpp"
 #include "ops/linear.hpp"
 #include "ops/matmul.hpp"
 #include "ops/mul.hpp"
@@ -30,20 +31,21 @@ inline void bind(py::module &m) {
     bind_add_rms_norm(m);
     bind_attention(m);
     bind_causal_softmax(m);
+    bind_embedding(m);
     bind_flash_attention(m);
-    bind_random_sample(m);
+    bind_kv_caching(m);
     bind_linear(m);
     bind_matmul(m);
     bind_mul(m);
     bind_paged_attention(m);
     bind_paged_attention_prefill(m);
     bind_paged_caching(m);
+    bind_random_sample(m);
     bind_rearrange(m);
     bind_rms_norm(m);
+    bind_rope(m);
     bind_silu(m);
     bind_swiglu(m);
-    bind_rope(m);
-    bind_embedding(m);
 }
 
 } // namespace infinicore::ops
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "infinicore/ops/kv_caching.hpp"
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+inline void bind_kv_caching(py::module &m) {
+    m.def("kv_caching_",
+          &op::kv_caching_,
+          py::arg("k_cache"),
+          py::arg("v_cache"),
+          py::arg("k"),
+          py::arg("v"),
+          py::arg("past_kv_lengths"),
+          R"doc(In-place Key-Value Caching.
+
+Updates the KV cache in-place with new key and value tensors.
+
+Args:
+    k_cache: Key cache tensor to update in-place
+    v_cache: Value cache tensor to update in-place
+    k: New key tensor to append
+    v: New value tensor to append
+    past_kv_lengths: Tensor containing current sequence lengths for each batch
+)doc");
+}
+
+} // namespace infinicore::ops
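
As a rough, framework-free sketch of the semantics described in the docstring (NumPy stands in for the real tensors, and a [batch, seq, ...] layout is assumed purely for illustration; the actual layout is fixed by the tensor descriptors):

import numpy as np

def kv_caching_reference(k_cache, v_cache, k, v, past_kv_lengths):
    # Illustration only: append each batch's new k/v rows at its current length.
    num_new = k.shape[1]  # new tokens per sequence on the assumed seq axis
    for b, pos in enumerate(past_kv_lengths):
        k_cache[b, pos:pos + num_new] = k[b]
        v_cache[b, pos:pos + num_new] = v[b]
    return k_cache, v_cache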

src/infiniop/ops/flash_attention/ninetoothed/descriptor.h

Lines changed: 20 additions & 18 deletions
@@ -67,24 +67,26 @@ class Descriptor final : public InfiniopDescriptor {
         constexpr auto block_size_m_{64};
         constexpr auto block_size_n_{64};
 
-        launch_flash_attention(stream,
-                               query,
-                               key,
-                               value,
-                               attn_mask,
-                               is_causal,
-                               scale,
-                               output,
-                               with_attn_mask,
-                               causal_variant,
-                               with_kv_cache_,
-                               emb_dim_,
-                               is_causal_,
-                               with_attn_mask_,
-                               causal_variant_,
-                               dtype_,
-                               block_size_m_,
-                               block_size_n_);
+        if (launch_flash_attention(stream,
+                                   query,
+                                   key,
+                                   value,
+                                   attn_mask,
+                                   is_causal,
+                                   scale,
+                                   output,
+                                   with_attn_mask,
+                                   causal_variant,
+                                   with_kv_cache_,
+                                   emb_dim_,
+                                   is_causal_,
+                                   with_attn_mask_,
+                                   causal_variant_,
+                                   dtype_,
+                                   block_size_m_,
+                                   block_size_n_)) {
+            return INFINI_STATUS_NOT_IMPLEMENTED;
+        }
 
         return INFINI_STATUS_SUCCESS;
     }
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import ninetoothed
+from . import kv_caching
+
+import infiniop.ninetoothed.build
+
+
+def build():
+    dtype_values = (
+        ninetoothed.float16,
+        ninetoothed.bfloat16,
+        ninetoothed.float32,
+    )
+
+    constexpr_param_grid = {
+        "emb_dim": (1, 16, 32, 64, 128, 256),
+        "dtype": dtype_values,
+        "block_size_m": (64,),
+        "block_size_n": (64,),
+    }
+
+    infiniop.ninetoothed.build.build(
+        kv_caching.premake,
+        constexpr_param_grid,
+        caller="cuda",
+        op_name="kv_caching",
+        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
+    )
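
The build step presumably expands constexpr_param_grid into one kernel specialization per grid point; a single point would reach the kernel's premake (defined in the module further down) roughly like this:

# One point of the grid, passed through to premake (sketch only).
arrangement_, application, tensors = kv_caching.premake(
    emb_dim=64,
    dtype=ninetoothed.float16,
    block_size_m=64,
    block_size_n=64,
)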
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+#ifndef KV_CACHING_H
+#define KV_CACHING_H
+
+#include "../../../handle.h"
+#include "../../../operator.h"
+#include "../../../tensor.h"
+
+#include "../../../../../build/ninetoothed/kv_caching.h"
+#include "../../../ninetoothed/utils.h"
+
+namespace op::kv_caching::ninetoothed {
+class Descriptor final : public InfiniopDescriptor {
+
+public:
+    Descriptor(
+        infiniopHandle_t handle,
+        infiniopTensorDescriptor_t k_cache_desc,
+        infiniopTensorDescriptor_t v_cache_desc,
+        infiniopTensorDescriptor_t k_desc,
+        infiniopTensorDescriptor_t v_desc,
+        infiniopTensorDescriptor_t past_kv_lengths_desc) : InfiniopDescriptor{handle->device, handle->device_id},
+                                                           k_cache_shape_{k_cache_desc->shape()},
+                                                           k_cache_strides_{k_cache_desc->strides()},
+                                                           v_cache_shape_{v_cache_desc->shape()},
+                                                           v_cache_strides_{v_cache_desc->strides()},
+                                                           k_shape_{k_desc->shape()},
+                                                           k_strides_{k_desc->strides()},
+                                                           v_shape_{v_desc->shape()},
+                                                           v_strides_{v_desc->strides()},
+                                                           past_kv_lengths_shape_{past_kv_lengths_desc->shape()},
+                                                           past_kv_lengths_strides_{past_kv_lengths_desc->strides()},
+                                                           dtype_{k_desc->dtype()} {}
+
+    ~Descriptor() = default;
+
+    size_t get_workspace_size() const { return 0; };
+
+    static infiniStatus_t create(
+        infiniopHandle_t handle,
+        Descriptor **desc_ptr,
+        infiniopTensorDescriptor_t k_cache,
+        infiniopTensorDescriptor_t v_cache,
+        infiniopTensorDescriptor_t k,
+        infiniopTensorDescriptor_t v,
+        infiniopTensorDescriptor_t past_kv_lengths) {
+        *desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths};
+        return INFINI_STATUS_SUCCESS;
+    }
+
+    infiniStatus_t calculate(
+        void *workspace, size_t workspace_size,
+        void *k_cache,
+        void *v_cache,
+        const void *k,
+        const void *v,
+        const void *past_kv_lengths,
+        void *stream) const {
+        auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}};
+        auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}};
+        auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}};
+        auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}};
+        auto past_kv_lengths_nt{::ninetoothed::Tensor{past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}};
+
+        if (launch_kv_caching(stream,
+                              k_cache_nt,
+                              v_cache_nt,
+                              k_nt,
+                              v_nt,
+                              past_kv_lengths_nt,
+                              k_shape_[3],
+                              dtype_,
+                              64, 64)) {
+            return INFINI_STATUS_NOT_IMPLEMENTED;
+        }
+
+        return INFINI_STATUS_SUCCESS;
+    }
+
+private:
+    using Size = ::ninetoothed::Tensor<>::Size;
+    using Stride = ::ninetoothed::Tensor<>::Stride;
+
+    std::vector<Size> k_cache_shape_;
+    std::vector<Stride> k_cache_strides_;
+
+    std::vector<Size> v_cache_shape_;
+    std::vector<Stride> v_cache_strides_;
+
+    std::vector<Size> k_shape_;
+    std::vector<Stride> k_strides_;
+    std::vector<Size> v_shape_;
+    std::vector<Stride> v_strides_;
+
+    std::vector<Size> past_kv_lengths_shape_;
+    std::vector<Stride> past_kv_lengths_strides_;
+
+    infiniDtype_t dtype_;
+};
+} // namespace op::kv_caching::ninetoothed
+
+#endif // KV_CACHING_H
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+import functools
+import ninetoothed
+from ninetoothed import Tensor
+
+
+def arrangement(
+    k_cache,
+    v_cache,
+    k,
+    v,
+    past_lengths,
+    block_size_m=ninetoothed.block_size(),
+    block_size_n=ninetoothed.block_size(),
+):
+    k_cache_arranged = k_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
+    v_cache_arranged = v_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
+
+    k_arranged = k.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
+    v_arranged = v.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
+
+    past_lengths_arranged = (
+        past_lengths.tile((1,))
+        .unsqueeze(1)
+        .unsqueeze(2)
+        .unsqueeze(3)
+        .unsqueeze(4)
+        .expand((-1, *k_arranged.shape))
+    )
+
+    return (
+        k_cache_arranged,
+        v_cache_arranged,
+        k_arranged,
+        v_arranged,
+        past_lengths_arranged,
+    )
+
+
+def application(k_cache, v_cache, k, v, past_lengths):
+    pos = past_lengths
+
+    for i in range(k.shape[-2]):
+        k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0]
+        v_cache[0, 0, pos + i, 0] = v[0, 0, i, 0]
+
+
+def premake(emb_dim=None, dtype=None, block_size_m=None, block_size_n=None):
+    arrangement_ = functools.partial(
+        arrangement, block_size_m=block_size_m, block_size_n=block_size_n
+    )
+
+    shape_options = (None, None, None, {"constexpr": True, "upper_bound": 256})
+
+    tensors = (
+        Tensor(4, dtype=dtype, shape_options=shape_options),
+        Tensor(4, dtype=dtype, shape_options=shape_options),
+        Tensor(4, dtype=dtype, shape_options=shape_options),
+        Tensor(4, dtype=dtype, shape_options=shape_options),
+        Tensor(1, dtype=ninetoothed.int64),
+    )
+
+    if emb_dim is not None:
+        for tensor in tensors:
+            tensor.shape = tensor.shape[:-1] + (emb_dim,)
+
+    return arrangement_, application, tensors
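
In application, each program instance walks the new tokens of its tile and writes them into the cache at an offset of past_lengths. A plain-Python sketch of that indexing on a single sequence slice (NumPy arrays standing in for ninetoothed tiles; shapes are made up for illustration):

import numpy as np

def apply_slice(cache, new, pos):
    # Mirrors the loop over k.shape[-2]: shift every new row by pos.
    for i in range(new.shape[0]):
        cache[pos + i] = new[i]

cache = np.zeros((16, 8), dtype=np.float16)  # max_seq x emb_dim (toy sizes)
new = np.ones((3, 8), dtype=np.float16)      # three new tokens
apply_slice(cache, new, pos=5)
assert np.array_equal(cache[5:8], new)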
