Skip to content

Commit 7b1d1dc

Browse files
committed
nvidia cache run
1 parent 7ee139d commit 7b1d1dc

File tree

19 files changed

+948
-1
lines changed

19 files changed

+948
-1
lines changed

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "ops/paged_caching.hpp"
1616
#include "ops/random_sample.hpp"
1717
#include "ops/rearrange.hpp"
18+
#include "ops/reshape_and_cache.hpp"
1819
#include "ops/rms_norm.hpp"
1920
#include "ops/rope.hpp"
2021
#include "ops/silu.hpp"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Declares the graph-op wrapper class (constructor + static execute) for
// ReshapeAndCache; the signature mirrors reshape_and_cache() below.
INFINICORE_GRAPH_OP_CLASS(ReshapeAndCache, Tensor &, Tensor &, Tensor &, Tensor &, Tensor &,
                          const std::string &, Tensor &, Tensor &);

/// Scatter per-token key/value vectors into the paged KV cache.
///
/// @param key            [num_tokens, num_heads, head_size]
/// @param value          [num_tokens, num_heads, head_size]
/// @param key_cache      [num_blocks, num_heads, head_size/x, block_size, x]
/// @param value_cache    [num_blocks, num_heads, head_size, block_size]
/// @param slot_mapping   [num_tokens] — presumably the destination cache slot
///                       per token; confirm against the backend kernel.
/// @param kv_cache_dtype string tag selecting the cache storage dtype.
/// @param k_scale        scale tensor for keys (likely for quantized caches — verify).
/// @param v_scale        scale tensor for values (likely for quantized caches — verify).
void reshape_and_cache(Tensor &key,          // [num_tokens, num_heads, head_size]
                       Tensor &value,        // [num_tokens, num_heads, head_size]
                       Tensor &key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
                       Tensor &value_cache,  // [num_blocks, num_heads, head_size, block_size]
                       Tensor &slot_mapping, // [num_tokens]
                       const std::string &kv_cache_dtype,
                       Tensor &k_scale,
                       Tensor &v_scale);

} // namespace infinicore::op

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "infiniop/ops/random_sample.h"
2727
#include "infiniop/ops/rearrange.h"
2828
#include "infiniop/ops/relu.h"
29+
#include "infiniop/ops/reshape_and_cache.h"
2930
#include "infiniop/ops/rms_norm.h"
3031
#include "infiniop/ops/rope.h"
3132
#include "infiniop/ops/sigmoid.h"
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
#ifndef __INFINIOP_RESHAPE_AND_CACHE_API_H__
#define __INFINIOP_RESHAPE_AND_CACHE_API_H__

#include "../operator_descriptor.h"
#include <stdint.h>

// Opaque descriptor handle for the ReshapeAndCache operator.
typedef struct InfiniopDescriptor *infiniopReshapeAndCacheDescriptor_t;

// Creates an operator descriptor from the operand tensor descriptors and the
// cache-dtype tag. The descriptor must later be released with
// infiniopDestroyReshapeAndCacheDescriptor().
__C __export infiniStatus_t infiniopCreateReshapeAndCacheDescriptor(
    infiniopHandle_t handle,
    infiniopReshapeAndCacheDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t key_desc,
    infiniopTensorDescriptor_t value_desc,
    infiniopTensorDescriptor_t key_cache_desc,
    infiniopTensorDescriptor_t value_cache_desc,
    infiniopTensorDescriptor_t slot_mapping_desc,
    const char *kv_cache_dtype);

// Queries the scratch-space size required to execute `desc`
// (presumably in bytes, matching the other infiniop ops — confirm).
__C __export infiniStatus_t infiniopGetReshapeAndCacheWorkspaceSize(
    infiniopReshapeAndCacheDescriptor_t desc, size_t *size);

// Executes the op: writes key/value into key_cache/value_cache at the
// positions given by slot_mapping. `stream` is the backend execution stream.
__C __export infiniStatus_t infiniopReshapeAndCache(
    infiniopReshapeAndCacheDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *key,
    void *value,
    void *key_cache,
    void *value_cache,
    const void *slot_mapping,
    const char *kv_cache_dtype,
    void *k_scale,
    void *v_scale,
    void *stream);

// Releases a descriptor created by infiniopCreateReshapeAndCacheDescriptor().
__C __export infiniStatus_t infiniopDestroyReshapeAndCacheDescriptor(
    infiniopReshapeAndCacheDescriptor_t desc);

#endif // __INFINIOP_RESHAPE_AND_CACHE_API_H__

python/infinicore/nn/functional/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .rope import RopeAlgo, rope
99
from .silu import silu
1010
from .swiglu import swiglu
11+
from .reshape_and_cache import reshape_and_cache
1112

1213
__all__ = [
1314
"causal_softmax",
@@ -21,4 +22,5 @@
2122
"silu",
2223
"swiglu",
2324
"paged_attention_v2",
25+
"reshape_and_cache",
2426
]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor, empty
3+
4+
5+
6+
def reshape_and_cache(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    slot_mapping: Tensor,
    kv_cache_dtype: str,
    k_scale: Tensor,
    v_scale: Tensor,
) -> None:
    """Scatter per-token key/value tensors into the paged KV cache (in place).

    Thin wrapper that forwards the underlying native tensors to the
    ``_infinicore`` binding; ``key_cache`` and ``value_cache`` are mutated.

    Args:
        key: ``[num_tokens, num_heads, head_size]`` keys for the new tokens.
        value: ``[num_tokens, num_heads, head_size]`` values for the new tokens.
        key_cache: Paged key cache, updated in place.
        value_cache: Paged value cache, updated in place.
        slot_mapping: ``[num_tokens]`` destination slot per token
            (presumably flat cache-slot indices — confirm with the kernel).
        kv_cache_dtype: String tag selecting the cache storage dtype.
        k_scale: Scale tensor for keys (used by quantized cache dtypes).
        v_scale: Scale tensor for values (used by quantized cache dtypes).
    """
    _infinicore.reshape_and_cache(
        key._underlying,
        value._underlying,
        key_cache._underlying,
        value_cache._underlying,
        slot_mapping._underlying,
        kv_cache_dtype,
        k_scale._underlying,
        v_scale._underlying,
    )

python/infinicore/tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,10 @@ def from_torch(torch_tensor) -> Tensor:
185185
infini_type = to_infinicore_dtype(torch_tensor.dtype)
186186
infini_device = infinicore.device(torch_tensor.device.type, 0)
187187
return Tensor(
188-
_infinicore.from_blob(
188+
_infinicore.strided_from_blob(
189189
torch_tensor.data_ptr(),
190190
list(torch_tensor.shape),
191+
list(torch_tensor.stride()),
191192
dtype=infini_type._underlying,
192193
device=infini_device._underlying,
193194
),
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
#include "infinicore/ops/reshape_and_cache.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

// Instantiates the per-device dispatcher table for this graph op.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(ReshapeAndCache);

// Constructor: validates operand placement, then dispatches to the backend
// registered for the tensors' device type.
ReshapeAndCache::ReshapeAndCache(Tensor &key,
                                 Tensor &value,
                                 Tensor &key_cache,
                                 Tensor &value_cache,
                                 Tensor &slot_mapping,
                                 const std::string &kv_cache_dtype,
                                 Tensor &k_scale,
                                 Tensor &v_scale) {
    // All main operands must live on the same device.
    // NOTE(review): k_scale/v_scale are not covered by this check — confirm
    // they are allowed to live elsewhere (e.g. host scalars) or add them.
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(key, value, key_cache, value_cache, slot_mapping);
    INFINICORE_GRAPH_OP_DISPATCH(key->device().getType(),
                                 key, value, key_cache, value_cache, slot_mapping,
                                 kv_cache_dtype, k_scale, v_scale);
}

// Static entry point: records the op into the graph under capture or runs it
// immediately (behavior supplied by the RECORD_OR_RUN macro).
void ReshapeAndCache::execute(Tensor &key,
                              Tensor &value,
                              Tensor &key_cache,
                              Tensor &value_cache,
                              Tensor &slot_mapping,
                              const std::string &kv_cache_dtype,
                              Tensor &k_scale,
                              Tensor &v_scale) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(
        ReshapeAndCache,
        key, value, key_cache, value_cache, slot_mapping,
        kv_cache_dtype, k_scale, v_scale);
}

// Public functional API declared in reshape_and_cache.hpp; forwards to execute().
void reshape_and_cache(Tensor &key,          // [num_tokens, num_heads, head_size]
                       Tensor &value,        // [num_tokens, num_heads, head_size]
                       Tensor &key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
                       Tensor &value_cache,  // [num_blocks, num_heads, head_size, block_size]
                       Tensor &slot_mapping, // [num_tokens]
                       const std::string &kv_cache_dtype,
                       Tensor &k_scale,
                       Tensor &v_scale) {
    ReshapeAndCache::execute(key, value, key_cache, value_cache, slot_mapping,
                             kv_cache_dtype, k_scale, v_scale);
}

} // namespace infinicore::op
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
#include "infinicore/ops/reshape_and_cache.hpp"

#include "../infiniop_impl.hpp"

namespace infinicore::op::reshape_and_cache_impl::infiniop {

// LRU-style descriptor cache (capacity 100) keyed by the hash computed in plan().
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, ReshapeAndCache, 100);

// Everything run() needs, captured at plan time so graph replay does no
// lookups or allocations.
struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor;
    graph::GraphTensor workspace;
    graph::GraphTensor key;
    graph::GraphTensor value;
    graph::GraphTensor key_cache;
    graph::GraphTensor value_cache;
    graph::GraphTensor slot_mapping;
    graph::GraphTensor k_scale;
    graph::GraphTensor v_scale;
    std::string kv_cache_dtype; // owned copy so c_str() stays valid across replays
};

// Builds (or fetches from cache) the infiniop descriptor, allocates its
// workspace, and packages the operands for run(). Returns a heap PlannedMeta
// released by cleanup().
void *plan(Tensor &key,
           Tensor &value,
           Tensor &key_cache,
           Tensor &value_cache,
           Tensor &slot_mapping,
           const std::string &kv_cache_dtype,
           Tensor &k_scale,
           Tensor &v_scale) {
    // Cache seed covers only the tensor operands.
    // NOTE(review): kv_cache_dtype is not folded into the seed here — confirm
    // the GET_OR_CREATE macro hashes its create-args too, otherwise two plans
    // differing only in dtype string could share a descriptor.
    size_t seed = hash_combine(key, value, key_cache, value_cache, slot_mapping);

    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
        Descriptor, descriptor, ReshapeAndCache,
        seed,
        key->desc(), value->desc(), key_cache->desc(), value_cache->desc(),
        slot_mapping->desc(), kv_cache_dtype.c_str());

    INFINIOP_WORKSPACE_TENSOR(workspace, ReshapeAndCache, descriptor);

    return new PlannedMeta{
        descriptor,
        graph::GraphTensor(workspace),
        graph::GraphTensor(key),
        graph::GraphTensor(value),
        graph::GraphTensor(key_cache),
        graph::GraphTensor(value_cache),
        graph::GraphTensor(slot_mapping),
        graph::GraphTensor(k_scale),
        graph::GraphTensor(v_scale),
        kv_cache_dtype};
}

// Executes the planned op on the current context stream.
void run(void *planned_meta) {
    auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);

    INFINICORE_CHECK_ERROR(
        infiniopReshapeAndCache(
            p->descriptor->desc,
            p->workspace->data(),
            // NOTE(review): numel() is passed where the C API takes a byte
            // size — assumes the workspace tensor is byte-typed; confirm.
            p->workspace->numel(),
            p->key->data(),
            p->value->data(),
            p->key_cache->data(),
            p->value_cache->data(),
            p->slot_mapping->data(),
            p->kv_cache_dtype.c_str(),
            p->k_scale->data(),
            p->v_scale->data(),
            context::getStream()));
}

// Frees the PlannedMeta allocated by plan() and nulls the caller's pointer.
void cleanup(void **planned_meta_ptr) {
    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

// Registers this plan/run/cleanup triple for every device type.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(ReshapeAndCache, &plan, &run, &cleanup);

} // namespace infinicore::op::reshape_and_cache_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "ops/paged_attention_prefill.hpp"
1717
#include "ops/paged_attention_v2.hpp"
1818
#include "ops/paged_caching.hpp"
19+
#include "ops/reshape_and_cache.hpp"
1920
#include "ops/random_sample.hpp"
2021
#include "ops/rearrange.hpp"
2122
#include "ops/rms_norm.hpp"
@@ -41,6 +42,7 @@ inline void bind(py::module &m) {
4142
bind_paged_attention_v2(m);
4243
bind_paged_attention_prefill(m);
4344
bind_paged_caching(m);
45+
bind_reshape_and_cache(m);
4446
bind_random_sample(m);
4547
bind_rearrange(m);
4648
bind_rms_norm(m);

0 commit comments

Comments (0)