Commit db19cc0 (parent 831e8a6)

issue/168 use n_blocks to init paged kv cache config, support fixed paged caching api

5 files changed: 28 additions & 40 deletions
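
The paged KV cache is now sized by an explicit block count instead of a byte budget: PagedKVCacheConfig takes num_blocks in place of max_kv_memory_bytes, and PagedKVCache uses that count directly. A minimal before/after sketch of the Python-side construction (the 8 GiB figure is just the old example default, the block count of 128 is arbitrary, and the import path is an assumption based on python/infinilm/cache/cache.py):

    from infinilm.cache.cache import PagedKVCacheConfig  # assumed import path

    # Before this commit: size the cache by a memory budget in bytes.
    # cache_config = PagedKVCacheConfig(max_kv_memory_bytes=8 * 1024 * 1024 * 1024, block_size=16)

    # After this commit: size the cache by a fixed number of blocks.
    cache_config = PagedKVCacheConfig(num_blocks=128, block_size=16)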

csrc/cache/kv_cache.cpp

Lines changed: 11 additions & 18 deletions

@@ -111,9 +111,9 @@ StaticKVCache::update(size_t layer_idx,
 // PagedKVCacheConfig
 // ==========================
 PagedKVCacheConfig::PagedKVCacheConfig(
-    size_t max_kv_memory_bytes,
+    size_t num_blocks,
     size_t block_size)
-    : max_kv_memory_bytes_(max_kv_memory_bytes),
+    : num_blocks_(num_blocks),
       block_size_(block_size) {
 }

@@ -123,8 +123,8 @@ PagedKVCacheConfig::unique_copy() const {
 }

 size_t
-PagedKVCacheConfig::max_kv_memory_bytes() const {
-    return max_kv_memory_bytes_;
+PagedKVCacheConfig::num_blocks() const {
+    return num_blocks_;
 }

 size_t
@@ -151,16 +151,8 @@ PagedKVCache::PagedKVCache(
       num_rank_v_heads_(num_v_heads / rank_info.tp_size),
       rank_num_layers_(num_layers),
       dtype_(dtype),
+      num_blocks_per_layer_(config.num_blocks()),
       block_size_(config.block_size()) {
-    num_blocks_per_layer_ = config.max_kv_memory_bytes()
-                            / (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_)
-                            / block_size_
-                            / rank_num_layers_
-                            / infinicore::dsize(dtype_);
-    if (num_blocks_per_layer_ == 0) {
-        throw std::runtime_error("Not enough memory for KV cache");
-    }
-
     // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
     k_caches_ = infinicore::Tensor::empty(
         {rank_num_layers_,
@@ -190,11 +182,12 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(

     auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);

-    infinicore::op::paged_caching_(k,
-                                   v,
-                                   k_cache_layer,
-                                   v_cache_layer,
-                                   slot_mapping);
+    infinicore::op::paged_caching_(
+        k_cache_layer,
+        v_cache_layer,
+        k,
+        v,
+        slot_mapping);
     return {k_cache_layer, v_cache_layer};
 }

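With this change, PagedKVCache no longer derives the per-layer block count from a byte budget inside its constructor; it takes config.num_blocks() directly. A caller that still wants to think in terms of memory has to do that conversion itself. Below is a rough Python sketch mirroring the deleted C++ arithmetic; the function name, parameter names, and the dtype size are illustrative assumptions, not part of the repository API:

    def blocks_per_layer_from_budget(
        max_kv_memory_bytes: int,
        k_dim: int,
        v_dim: int,
        num_rank_k_heads: int,
        num_rank_v_heads: int,
        rank_num_layers: int,
        block_size: int = 16,
        dtype_size: int = 2,  # e.g. 2 bytes for fp16/bf16 (assumption)
    ) -> int:
        # Same chain of integer divisions as the removed constructor code.
        num_blocks = (
            max_kv_memory_bytes
            // (k_dim * num_rank_k_heads + v_dim * num_rank_v_heads)
            // block_size
            // rank_num_layers
            // dtype_size
        )
        if num_blocks == 0:
            raise RuntimeError("Not enough memory for KV cache")
        return num_blocks
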
csrc/cache/kv_cache.hpp

Lines changed: 3 additions & 3 deletions

@@ -85,15 +85,15 @@ class StaticKVCache final : public Cache {
 class PagedKVCacheConfig final : public CacheConfig {
 public:
     PagedKVCacheConfig(
-        size_t max_kv_memory_bytes,
+        size_t num_blocks,
         size_t block_size = 16);

     std::unique_ptr<CacheConfig> unique_copy() const override;
-    size_t max_kv_memory_bytes() const;
+    size_t num_blocks() const;
     size_t block_size() const;

 private:
-    size_t max_kv_memory_bytes_;
+    size_t num_blocks_;
     size_t block_size_;
 };

csrc/pybind11/cache/cache.hpp

Lines changed: 3 additions & 3 deletions

@@ -36,11 +36,11 @@ inline void bind_cache(py::module &m) {
                std::shared_ptr<infinilm::cache::PagedKVCacheConfig>>(m, "PagedKVCacheConfig")
         .def(
             py::init<size_t, size_t>(),
-            py::arg("max_kv_memory_bytes"),
+            py::arg("num_blocks"),
             py::arg("block_size") = 16)
         .def(
-            "max_kv_memory_bytes",
-            &infinilm::cache::PagedKVCacheConfig::max_kv_memory_bytes)
+            "num_blocks",
+            &infinilm::cache::PagedKVCacheConfig::num_blocks)
         .def(
             "block_size",
             &infinilm::cache::PagedKVCacheConfig::block_size)

examples/jiuge.py

Lines changed: 9 additions & 14 deletions

@@ -89,13 +89,6 @@ def get_args():
         help="use paged cache",
     )

-    parser.add_argument(
-        "--max-kvcache-size",
-        type=int,
-        default=8 * 1024 * 1024 * 1024,
-        help="max size (in bytes) allocated to paged kv cache",
-    )
-
     return parser.parse_args()


@@ -109,7 +102,7 @@ def test(
 ):
     model_path = os.path.expanduser(model_path)
     # ---------------------------------------------------------------------------- #
-    # 创建模型,
+    # Create Model
     # ---------------------------------------------------------------------------- #
     model = InferEngine(
         model_path,
@@ -118,12 +111,12 @@ def test(
     )

     # ---------------------------------------------------------------------------- #
-    # 加载权重
+    # Load Weights
     # ---------------------------------------------------------------------------- #
     load_model_state_dict_by_file(model, model_path, dtype=model.config.dtype)

     # ---------------------------------------------------------------------------- #
-    # 创建 tokenizer
+    # create tokenizer
     # ---------------------------------------------------------------------------- #
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

@@ -146,7 +139,7 @@ def test(
     )

     # ---------------------------------------------------------------------------- #
-    # token编码
+    # tokenize
     # ---------------------------------------------------------------------------- #
     # prompt = "山东最高的山是?"
     if isinstance(prompts, str):
@@ -165,11 +158,13 @@ def test(
     ]  # List: [[1, 1128, 526, 366, 29892]]

     # ---------------------------------------------------------------------------- #
-    # 创建KVCache
+    # Create KVCache
     # ---------------------------------------------------------------------------- #
     if enable_paged_attn:
+        batch_size = 1 if prompts is str else len(prompts)
+        max_total_tokens = max_new_tokens + len(input_ids_list[0])
         cache_config = PagedKVCacheConfig(
-            max_kv_memory_bytes=args.max_kvcache_size, block_size=16
+            num_blocks=(max_total_tokens // 16 + 1) * batch_size, block_size=16
         )
     else:
         batch_size = 1 if prompts is str else len(prompts)
@@ -181,7 +176,7 @@ def test(
     model.reset_cache(cache_config)

     # ---------------------------------------------------------------------------- #
-    # 自回归生成
+    # Generate
     # ---------------------------------------------------------------------------- #
     print(input_contents[0], end="", flush=True)
     input_ids_infini = infinicore.from_list(input_ids_list)
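
For reference, the updated example now sizes the paged cache from the token budget rather than from the removed --max-kvcache-size byte budget: each sequence gets enough 16-token blocks to hold its prompt plus max_new_tokens. A quick check of the arithmetic with made-up numbers:

    block_size = 16
    batch_size = 4
    prompt_len = 10                                   # stand-in for len(input_ids_list[0])
    max_new_tokens = 256
    max_total_tokens = max_new_tokens + prompt_len    # 266
    num_blocks = (max_total_tokens // block_size + 1) * batch_size
    print(num_blocks)                                 # (16 + 1) * 4 = 68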

python/infinilm/cache/cache.py

Lines changed: 2 additions & 2 deletions

@@ -16,11 +16,11 @@ def __init__(self, max_batch_size: int = 1, max_cache_len: int = 0):
 class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig):
     def __init__(
         self,
-        max_kv_memory_bytes: int,
+        num_blocks: int,
         block_size: int = 16,
     ):
         _infinilm.PagedKVCacheConfig.__init__(
             self,
-            max_kv_memory_bytes,
+            num_blocks,
             block_size,
         )
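
Putting the pieces together, a minimal usage sketch of the updated wrapper (the import path and the block count are assumptions for illustration; num_blocks() and block_size() are the accessors bound in csrc/pybind11/cache/cache.hpp above):

    from infinilm.cache.cache import PagedKVCacheConfig  # assumed import path

    cfg = PagedKVCacheConfig(num_blocks=64, block_size=16)
    print(cfg.num_blocks(), cfg.block_size())  # -> 64 16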
