Skip to content

Commit ec99504

Browse files
committed
issue/923 - ninetoothed kv caching for nv, il, mtx
1 parent 7249233 commit ec99504

File tree

16 files changed

+682
-4
lines changed

16 files changed

+682
-4
lines changed

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "ops/add_rms_norm.hpp"
55
#include "ops/attention.hpp"
66
#include "ops/causal_softmax.hpp"
7+
#include "ops/kv_caching.hpp"
78
#include "ops/matmul.hpp"
89
#include "ops/ones.hpp"
910
#include "ops/paged_attention.hpp"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Graph-capturable KV-caching op. Declares class `KVCaching` whose execute
// signature is (Tensor k_cache, Tensor v_cache, const Tensor &k,
// const Tensor &v, const Tensor &past_kv_lengths).
INFINICORE_GRAPH_OP_CLASS(KVCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);

// In-place KV cache update (trailing underscore = in-place convention).
//   k_cache, v_cache  : cache tensors, mutated in place
//   k, v              : new key/value tensors to append
//   past_kv_lengths   : per-batch current sequence lengths used as the
//                       write offsets into the caches
void kv_caching_(Tensor k_cache,
                 Tensor v_cache,
                 const Tensor &k,
                 const Tensor &v,
                 const Tensor &past_kv_lengths);

} // namespace infinicore::op

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "infiniop/ops/dequantize_awq.h"
1212
#include "infiniop/ops/gelu.h"
1313
#include "infiniop/ops/gemm.h"
14+
#include "infiniop/ops/kv_caching.h"
1415
#include "infiniop/ops/layer_norm.h"
1516
#include "infiniop/ops/logsoftmax.h"
1617
#include "infiniop/ops/lp_norm.h"

include/infiniop/ops/kv_caching.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
#ifndef __INFINIOP_KV_CACHING_API_H__
#define __INFINIOP_KV_CACHING_API_H__

#include "../operator_descriptor.h"

/* Opaque handle for a KV-caching operator instance. */
typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t;

/* Creates a KV-caching descriptor for the given tensor layouts.
 * k_cache/v_cache are the cache tensors updated in place; k/v are the new
 * key/value tensors; past_kv_lengths holds the per-batch sequence lengths
 * that determine where the new entries are written.
 * The descriptor must be released with infiniopDestroyKVCachingDescriptor. */
__C __export infiniStatus_t infiniopCreateKVCachingDescriptor(
    infiniopHandle_t handle,
    infiniopKVCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths);

/* Queries the scratch-workspace size (in bytes) required by infiniopKVCaching. */
__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size);

/* Runs the KV-cache update on `stream`. `workspace` must point to at least
 * `workspace_size` bytes as reported by infiniopGetKVCachingWorkspaceSize.
 * k_cache and v_cache are written in place; the remaining tensors are read-only. */
__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc,
                                              void *workspace,
                                              size_t workspace_size,
                                              void *k_cache,
                                              void *v_cache,
                                              const void *k,
                                              const void *v,
                                              const void *past_kv_lengths,
                                              void *stream);

/* Releases a descriptor created by infiniopCreateKVCachingDescriptor. */
__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc);

#endif /* __INFINIOP_KV_CACHING_API_H__ */

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from infinicore.ops.add import add
4646
from infinicore.ops.add_rms_norm import add_rms_norm
4747
from infinicore.ops.attention import attention
48+
from infinicore.ops.kv_caching import kv_caching
4849
from infinicore.ops.matmul import matmul
4950
from infinicore.ops.mul import mul
5051
from infinicore.ops.narrow import narrow
@@ -115,6 +116,7 @@
115116
"add_rms_norm",
116117
"add_rms_norm_",
117118
"attention",
119+
"kv_caching",
118120
"matmul",
119121
"mul",
120122
"narrow",
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
from infinicore.lib import _infinicore


def kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
    """Update the KV cache in place with new key/value tensors.

    Args:
        k_cache: Key cache tensor, updated in place.
        v_cache: Value cache tensor, updated in place.
        k: New key tensor to append.
        v: New value tensor to append.
        past_kv_lengths: Tensor of current per-batch sequence lengths,
            used as write offsets into the caches.

    Returns:
        The ``(k_cache, v_cache)`` pair (the same objects, mutated in
        place), returned for call-chaining convenience.
    """
    # Forward the wrapped native tensors to the in-place C++ kernel.
    _infinicore.kv_caching_(
        k_cache._underlying,
        v_cache._underlying,
        k._underlying,
        v._underlying,
        past_kv_lengths._underlying,
    )

    return k_cache, v_cache
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
#include "infinicore/ops/kv_caching.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

// Instantiates the per-device dispatcher tables declared by
// INFINICORE_GRAPH_OP_CLASS for KVCaching.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(KVCaching);

// Validates that all operands live on one device, then dispatches to the
// backend implementation registered for that device type.
KVCaching::KVCaching(Tensor k_cache,
                     Tensor v_cache,
                     const Tensor &k,
                     const Tensor &v,
                     const Tensor &past_kv_lengths) {
    // Mixed-device operands would silently read garbage; fail fast instead.
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, past_kv_lengths);
    INFINICORE_GRAPH_OP_DISPATCH(k_cache->device().getType(),
                                 k_cache,
                                 v_cache,
                                 k,
                                 v,
                                 past_kv_lengths);
}

// Records the op into the graph under construction, or runs it eagerly
// when no graph capture is active (macro decides which).
void KVCaching::execute(Tensor k_cache,
                        Tensor v_cache,
                        const Tensor &k,
                        const Tensor &v,
                        const Tensor &past_kv_lengths) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(KVCaching,
                                      k_cache,
                                      v_cache,
                                      k,
                                      v,
                                      past_kv_lengths);
}

// Public entry point; trailing underscore marks the in-place contract
// (k_cache and v_cache are mutated).
void kv_caching_(Tensor k_cache,
                 Tensor v_cache,
                 const Tensor &k,
                 const Tensor &v,
                 const Tensor &past_kv_lengths) {
    KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths);
}

} // namespace infinicore::op
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
#include "../infiniop_impl.hpp"
#include "infinicore/ops/kv_caching.hpp"

namespace infinicore::op::kv_caching_impl::infiniop {

// Capacity-bounded cache (100 entries) of infiniop KVCaching descriptors,
// keyed by a hash of the participating tensor descriptors.
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, KVCaching, 100);

// Everything needed to replay the planned op later: the (shared) cached
// descriptor plus graph handles that keep the operand tensors alive.
struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor;
    graph::GraphTensor workspace, k_cache, v_cache, k, v, past_kv_lengths;
};

// Builds (or fetches from the cache) the infiniop descriptor and allocates
// its workspace. Returns an owning PlannedMeta*; ownership transfers to the
// graph runtime, which releases it through cleanup().
void *plan(Tensor k_cache,
           Tensor v_cache,
           const Tensor &k,
           const Tensor &v,
           const Tensor &past_kv_lengths) {
    // Cache key: identical descriptor layouts reuse the same descriptor.
    size_t seed = hash_combine(k_cache, v_cache, k, v, past_kv_lengths);

    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
        Descriptor, descriptor, KVCaching,
        seed, k_cache->desc(), v_cache->desc(),
        k->desc(), v->desc(), past_kv_lengths->desc());

    // Allocates a tensor sized by infiniopGetKVCachingWorkspaceSize.
    INFINIOP_WORKSPACE_TENSOR(workspace, KVCaching, descriptor);

    auto planned = new PlannedMeta{
        descriptor,
        graph::GraphTensor(workspace),
        graph::GraphTensor(k_cache),
        graph::GraphTensor(v_cache),
        graph::GraphTensor(k),
        graph::GraphTensor(v),
        graph::GraphTensor(past_kv_lengths)};

    return planned;
}

// Executes the planned op on the current context stream.
void run(void *planned_meta) {
    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);

    // Fix: forward the workspace allocated in plan() instead of nullptr/0.
    // A backend whose GetWorkspaceSize is non-zero would otherwise be handed
    // a null scratch buffer; when the required size is 0 this is equivalent.
    // NOTE(review): workspace tensors are byte-typed in sibling impls, so
    // numel() is the byte count — confirm against other infiniop op impls.
    INFINICORE_CHECK_ERROR(infiniopKVCaching(
        planned->descriptor->desc,
        planned->workspace->data(),
        planned->workspace->numel(),
        planned->k_cache->data(),
        planned->v_cache->data(),
        planned->k->data(),
        planned->v->data(),
        planned->past_kv_lengths->data(),
        context::getStream()));
}

// Frees the PlannedMeta created by plan() and nulls the caller's pointer so
// a double cleanup is a harmless no-op.
void cleanup(void **planned_meta_ptr) {
    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(KVCaching, &plan, &run, cleanup);

} // namespace infinicore::op::kv_caching_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "ops/attention.hpp"
88
#include "ops/causal_softmax.hpp"
99
#include "ops/embedding.hpp"
10+
#include "ops/kv_caching.hpp"
1011
#include "ops/linear.hpp"
1112
#include "ops/matmul.hpp"
1213
#include "ops/mul.hpp"
@@ -29,13 +30,14 @@ inline void bind(py::module &m) {
2930
bind_add_rms_norm(m);
3031
bind_attention(m);
3132
bind_causal_softmax(m);
32-
bind_random_sample(m);
33+
bind_kv_caching(m);
3334
bind_linear(m);
3435
bind_matmul(m);
3536
bind_mul(m);
3637
bind_paged_attention(m);
3738
bind_paged_attention_prefill(m);
3839
bind_paged_caching(m);
40+
bind_random_sample(m);
3941
bind_rearrange(m);
4042
bind_rms_norm(m);
4143
bind_silu(m);
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/kv_caching.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Registers the in-place KV caching op as `kv_caching_` on the given
// pybind11 module. Called from the central bind() in ops.hpp.
inline void bind_kv_caching(py::module &m) {
    m.def("kv_caching_",
          &op::kv_caching_,
          py::arg("k_cache"),
          py::arg("v_cache"),
          py::arg("k"),
          py::arg("v"),
          py::arg("past_kv_lengths"),
          R"doc(In-place Key-Value Caching.

Updates the KV cache in-place with new key and value tensors.

Args:
    k_cache: Key cache tensor to update in-place
    v_cache: Value cache tensor to update in-place
    k: New key tensor to append
    v: New value tensor to append
    past_kv_lengths: Tensor containing current sequence lengths for each batch
)doc");
}

} // namespace infinicore::ops

0 commit comments

Comments
 (0)