diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp
index 214d20e3..bc373478 100644
--- a/csrc/cache/kv_cache.cpp
+++ b/csrc/cache/kv_cache.cpp
@@ -1,7 +1,7 @@
 #include "kv_cache.hpp"
 #include "../utils.hpp"
-
+#include "infinicore/ops.hpp"
 #include

 namespace infinilm::cache {

@@ -155,6 +155,7 @@ PagedKVCache::PagedKVCache(
     num_blocks_per_layer_ = config.max_kv_memory_bytes()
                           / (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_)
                           / block_size_
+                          / rank_num_layers_
                           / infinicore::dsize(dtype_);
     if (num_blocks_per_layer_ == 0) {
         throw std::runtime_error("Not enough memory for KV cache");
@@ -187,11 +188,78 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
     const infinicore::Tensor &v,
     const infinicore::Tensor &slot_mapping) {
+    auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
+
+    infinicore::op::paged_caching_(k,
+                                   v,
+                                   k_cache_layer,
+                                   v_cache_layer,
+                                   slot_mapping);
+    return {k_cache_layer, v_cache_layer};
+}
+
+std::tuple<infinicore::Tensor, infinicore::Tensor>
+PagedKVCache::get_paged_kv(size_t layer_idx) {
     auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
     auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
+    return {k_cache_layer, v_cache_layer};
+}
+
+std::tuple<infinicore::Tensor, infinicore::Tensor>
+PagedKVCache::get_contiguous_kv(
+    size_t layer_idx,
+    const infinicore::Tensor block_tables,
+    const infinicore::Tensor cache_lens,
+    const infinicore::Tensor input_offsets,
+    size_t request_id) {
+    ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64);
+    ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64);
+    ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64);

-    /// @todo: implement paged cache update here
+    auto nreq = block_tables->size(0);
+    auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
+    auto cache_lens_cpu = cache_lens->to(infinicore::Device::cpu());
+    auto input_offsets_cpu = input_offsets->to(infinicore::Device::cpu());
+    infinicore::context::syncDevice();

-    return {k_cache_layer, v_cache_layer};
+    // [num_blocks, num_rank_v_heads, block_size, v_dim]
+    auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
+
+    auto req = request_id;
+    auto cache_lens_ptr = reinterpret_cast<const int64_t *>(cache_lens_cpu->data());
+    auto input_offsets_ptr = reinterpret_cast<const int64_t *>(input_offsets_cpu->data());
+    int64_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]);
+
+    auto full_k = infinicore::Tensor::empty(
+        {num_rank_k_heads_, (size_t)total_len, k_dim_},
+        k_cache_layer->dtype(), k_cache_layer->device());
+
+    auto full_v = infinicore::Tensor::empty(
+        {num_rank_v_heads_, (size_t)total_len, v_dim_},
+        v_cache_layer->dtype(), v_cache_layer->device());
+
+    size_t nblocks = total_len / block_size_;
+    size_t r = total_len % block_size_;
+
+    for (size_t b = 0; b < nblocks; b++) {
+        size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data()));
+
+        full_k->narrow({{1, b * block_size_, block_size_}})
+            ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
+        full_v->narrow({{1, b * block_size_, block_size_}})
+            ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
+    }
+
+    if (r > 0) {
+        size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data()));
+
+        full_k->narrow({{1, nblocks * block_size_, r}})
+            ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
+        full_v->narrow({{1, nblocks * block_size_, r}})
+            ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
+    }
+
+    return {full_k, full_v};
 }
+
 } // namespace infinilm::cache
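For reference, the block-budget arithmetic in the PagedKVCache constructor and the block-by-block gather done by get_contiguous_kv() can be sketched in a few lines of Python. This is an illustration only, not part of the patch; the helper names (count_blocks_per_layer, gather_contiguous_kv) and the plain-list inputs are assumptions made for the example.

    def count_blocks_per_layer(max_kv_memory_bytes, k_dim, v_dim,
                               num_k_heads, num_v_heads,
                               block_size, num_layers, dtype_size):
        # Mirrors: bytes / (k_dim*k_heads + v_dim*v_heads) / block_size / num_layers / dsize(dtype)
        return (max_kv_memory_bytes
                // (k_dim * num_k_heads + v_dim * num_v_heads)
                // block_size // num_layers // dtype_size)

    def gather_contiguous_kv(cache_lens, input_offsets, block_table, block_size, req=0):
        # total_len = cached tokens + new tokens of request `req`
        total_len = cache_lens[req] + (input_offsets[req + 1] - input_offsets[req])
        nblocks, rem = divmod(total_len, block_size)
        slots = [(block_table[b], off) for b in range(nblocks) for off in range(block_size)]
        if rem:  # trailing, partially filled block
            slots += [(block_table[nblocks], off) for off in range(rem)]
        return slots  # one (block_id, offset) pair per token, in order

    # 2 cached tokens + 3 new tokens with block_size=4 span blocks 7 and 9
    print(gather_contiguous_kv([2], [0, 3], [7, 9], 4))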
diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp
index bcce66e5..699530d0 100644
--- a/csrc/cache/kv_cache.hpp
+++ b/csrc/cache/kv_cache.hpp
@@ -113,7 +113,7 @@ class PagedKVCache final : public Cache {
     /**
      * @brief Update Paged KV cache at a given layer given slot info for each token.
      *
-     * @param layer_idx Which transformer layer
+     * @param layer_idx Which paged attention layer
      * @param k [num_rank_k_heads, seq_len, k_dim]
      * @param v [num_rank_v_heads, seq_len, v_dim]
      * @param slot_mapping [seq_len]
@@ -128,7 +128,41 @@ class PagedKVCache final : public Cache {
                const infinicore::Tensor &v,
                const infinicore::Tensor &slot_mapping);

-    ~PagedKVCache() override = default;
+    /**
+     * @brief Get Paged KV cache at a given layer.
+     *
+     * @param layer_idx Which paged attention layer
+     *
+     * @return (full_k, full_v)
+     *         full_k: [num_blocks, num_rank_k_heads, block_size, k_dim]
+     *         full_v: [num_blocks, num_rank_v_heads, block_size, v_dim]
+     */
+    std::tuple<infinicore::Tensor, infinicore::Tensor>
+    get_paged_kv(size_t layer_idx);
+
+    /**
+     * @brief Get contiguous KV cache at a given layer, given the request info
+     *        among a continuous request batch.
+     *
+     * @param layer_idx Which paged attention layer
+     * @param block_tables [num_requests, max_blocks_per_request]
+     * @param cache_lens [num_requests]
+     * @param input_offsets [num_requests + 1]
+     * @param request_id Which request among a continuous batch of requests
+     *
+     * @return (full_k, full_v)
+     *         full_k: [num_rank_k_heads, total_len, k_dim]
+     *         full_v: [num_rank_v_heads, total_len, v_dim]
+     */
+    std::tuple<infinicore::Tensor, infinicore::Tensor>
+    get_contiguous_kv(size_t layer_idx,
+                      const infinicore::Tensor block_tables,
+                      const infinicore::Tensor cache_lens,
+                      const infinicore::Tensor input_offsets,
+                      size_t request_id = 0);
+
+    ~PagedKVCache() override
+        = default;

 private:
     infinicore::Size k_dim_;
diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp
index 663d03ab..8863568a 100644
--- a/csrc/engine/infer_engine.cpp
+++ b/csrc/engine/infer_engine.cpp
@@ -56,8 +56,50 @@ std::vector> InferEng
 //------------------------------------------------------
 // forward
 //------------------------------------------------------
-infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
-    return {input_ids, position_ids, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping};
+infinilm::InfinilmModel::Input InferEngine::Input::to_model_input(infinicore::Device device) const {
+
+    std::optional<infinicore::Tensor> position_ids_on_device;
+    if (position_ids.has_value()) {
+        position_ids_on_device = position_ids.value()->to(device);
+    }
+
+    std::optional<infinicore::Tensor> cache_lengths_on_device;
+    if (cache_lengths.has_value()) {
+        if (block_tables.has_value()) {
+            cache_lengths_on_device = cache_lengths.value()->to(device);
+        } else { // @todo: only the paged kv cache supports device tensors so far
+            cache_lengths_on_device = cache_lengths.value();
+        }
+    }
+
+    std::optional<infinicore::Tensor> input_lengths_on_device;
+    if (input_lengths.has_value()) {
+        input_lengths_on_device = input_lengths.value()->to(device);
+    }
+
+    std::optional<infinicore::Tensor> input_offsets_on_device;
+    if (input_offsets.has_value()) {
+        input_offsets_on_device = input_offsets.value()->to(device);
+    }
+
+    std::optional<infinicore::Tensor> block_tables_on_device;
+    if (block_tables.has_value()) {
+        block_tables_on_device = block_tables.value()->to(device);
+    }
+
+    std::optional<infinicore::Tensor> slot_mapping_on_device;
+    if (slot_mapping.has_value()) {
+        slot_mapping_on_device = slot_mapping.value()->to(device);
+    }
+
+    return {
+        input_ids, // @todo: on device in the future
+        position_ids_on_device,
+        cache_lengths_on_device,
+        input_lengths_on_device,
+        input_offsets_on_device,
+        block_tables_on_device,
+        slot_mapping_on_device};
 }

 InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp
index 3c335b23..d8acb5a3 100644
--- a/csrc/engine/rank_worker.cpp
+++ b/csrc/engine/rank_worker.cpp
@@ -206,7 +206,7 @@ void RankWorker::thread_loop() {
             local_param_name = pending_param_name_;
             local_param = pending_param_;
         } else if (local_cmd == Command::RUN) {
-            local_args = pending_args_.to_model_input();
+            local_args = pending_args_.to_model_input(rank_info_.device);
         } else if (local_cmd == Command::RESET_CACHE) {
             if (pending_cache_config_ != nullptr) {
                 local_cache_config = pending_cache_config_->unique_copy();
@@ -254,13 +254,18 @@
             auto random_val{pending_args_.random_val};

             const auto &logits_shape{logits->shape()};
-            const auto &batch_size{logits_shape[0]};
             const auto &vocab_size{logits_shape[2]};
+            const auto &total_len{logits_shape[1]};
+            const auto &batch_size{logits_shape[0]};
+
+            auto n_req = pending_args_.input_offsets.value()->size(0);
+            int64_t *input_lengths = (int64_t *)pending_args_.input_lengths.value()->data();
+            int64_t *input_offsets = (int64_t *)pending_args_.input_offsets.value()->data();

-            auto output_ids{infinicore::Tensor::empty({batch_size}, infinicore::DataType::I32, rank_info_.device)};
+            auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};

-            for (auto i{decltype(batch_size)(0)}; i < batch_size; ++i) {
-                auto score{logits->narrow({{0, i, 1}})->view({vocab_size})};
+            for (auto i{decltype(n_req)(0)}; i < n_req; ++i) {
+                auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i] + input_lengths[i] - 1), 1}})->view({vocab_size})};
                 auto out{output_ids->narrow({{0, i, 1}})->view({})};
                 infinicore::op::random_sample_(
                     out, score, random_val, top_p, top_k, temperature);
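The sampling change in rank_worker.cpp now picks one logit row per request rather than one per batch element: the logits are viewed as [batch_size * total_len, vocab_size] and the row of each request's last input token is selected. A small Python sketch of that index arithmetic follows; it is illustrative only, and last_token_rows is not a function in the codebase.

    def last_token_rows(input_lengths, input_offsets):
        # Row of each request's final input token in the flattened
        # [batch_size * total_len, vocab_size] logits view.
        return [input_offsets[i] + input_lengths[i] - 1
                for i in range(len(input_lengths))]

    # Two requests of lengths 3 and 2 packed back to back sample rows 2 and 4
    print(last_token_rows([3, 2], [0, 3]))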
diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp
index 89640028..fbe0339f 100644
--- a/csrc/engine/rank_worker.hpp
+++ b/csrc/engine/rank_worker.hpp
@@ -47,7 +47,7 @@ class RankWorker {

         float random_val{0.1};

-        infinilm::InfinilmModel::Input to_model_input() const;
+        infinilm::InfinilmModel::Input to_model_input(infinicore::Device device) const;
     };

     struct Output {
diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp
index 78be6a87..e0f7341e 100644
--- a/csrc/models/llama/llama_attention.cpp
+++ b/csrc/models/llama/llama_attention.cpp
@@ -1,5 +1,6 @@
 #include "llama_attention.hpp"

+#include "../../utils.hpp"
 #include "infinicore/nn/linear.hpp"
 #include "infinicore/nn/rope.hpp"
 #include "infinicore/ops.hpp"
@@ -43,6 +44,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
     } else {
         throw std::runtime_error("num_attention_heads / tp_size error.");
     }
+    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));

     // Initialize projection layers
     INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_,
                              config.num_attention_heads, config.num_key_value_heads, use_bias_,
@@ -52,17 +54,10 @@
                              dtype, device, tp_rank, tp_size, rank_info.comm);
 }

-infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
-                                           const infinicore::Tensor &position_ids,
-                                           std::shared_ptr<cache::Cache> kv_cache,
-                                           std::optional<infinicore::Tensor> cache_lengths,
-                                           std::optional<infinicore::Tensor> input_lengths,
-                                           std::optional<infinicore::Tensor> input_offsets,
-                                           std::optional<infinicore::Tensor> block_tables,
-                                           std::optional<infinicore::Tensor> slot_mapping) const {
-    if (!rotary_emb_) {
-        throw std::runtime_error("LlamaAttention: rotary_emb not configured");
-    }
+infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
+                                            const infinicore::Tensor &position_ids,
+                                            std::shared_ptr<cache::Cache> kv_cache,
+                                            std::optional<infinicore::Tensor> cache_lengths) const {
     // Input shape: [batch, seq_len, hidden_size]
     auto hidden_states_mutable = hidden_states;
     auto shape = hidden_states->shape();
@@ -73,7 +68,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
     auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

     // 2. Reshape for multi-head attention
-    // Reshape Q, K, V to include batch dimension
     // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
     // The view operation requires the tensor to be contiguous in the required dimensions

@@ -114,13 +108,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
         auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value());
         k_total = k_total_tmp;
         v_total = v_total_tmp;
-    } else if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
-        auto [k_total_tmp, v_total_tmp] = paged_kv_cache->update(layer_idx_, k_permuted, v_permuted, slot_mapping.value());
-        k_total = k_total_tmp;
-        v_total = v_total_tmp;
-
-        /// @todo Implement paged attention here.
-        throw std::runtime_error("LlamaAttention: Paged attention not implemented");
     } else {
         throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
     }
@@ -134,8 +121,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
     auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]

-    float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
-    auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [bs * n_kv_head, ng * seq_len, total_seq_len]
+    auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]

     auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
     infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
@@ -152,6 +138,119 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
     return output;
 }

+infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states,
+                                                  const infinicore::Tensor &position_ids,
+                                                  std::shared_ptr<cache::PagedKVCache> paged_kv_cache,
+                                                  std::optional<infinicore::Tensor> cache_lengths,
+                                                  std::optional<infinicore::Tensor> input_lengths,
+                                                  std::optional<infinicore::Tensor> input_offsets,
+                                                  std::optional<infinicore::Tensor> block_tables,
+                                                  std::optional<infinicore::Tensor> slot_mapping) const {
+    ASSERT(block_tables.has_value());
+    ASSERT(input_lengths.has_value());
+    ASSERT(slot_mapping.has_value());
+
+    // Input shape: [batch, seq_len, hidden_size]
+    auto hidden_states_mutable = hidden_states;
+    auto shape = hidden_states->shape();
+    size_t batch_size = shape[0];
+    size_t seq_len = shape[1];
+
+    // Only batch_size == 1 is supported; all requests are flattened along the seq_len dimension
+    ASSERT_EQ(batch_size, 1);
+    // Decode only if total_len == num_requests
+    bool is_prefill = (seq_len != input_lengths.value()->shape()[0]);
+
+    // 1. Project Q, K, V
+    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
+
+    // 2. Reshape for multi-head attention
+
+    // Reshape Q, K, V to include batch dimension
+    // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
+    // The view operation requires the tensor to be contiguous in the required dimensions
+    auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_});
+    auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
+    auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});
+
+    // 3. Prepare position_ids for RoPE - align with Python pattern
+    auto pos_shape = position_ids->shape();
+    infinicore::Tensor pos_ids_for_rope = position_ids;
+    if (pos_shape.size() == 2) {
+        auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
+        pos_ids_for_rope = pos_narrowed->view({pos_shape[1]});
+    } else if (pos_shape.size() == 1) {
+        pos_ids_for_rope = position_ids;
+    } else {
+        throw std::runtime_error("Unexpected position_ids shape");
+    }
+
+    // 4. Apply RoPE to Q and K
+    rotary_emb_->forward(q_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_q_head, head_dim]
+    rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim]
+
+    // 5. Prepare KV caches
+    // Ensure contiguous after permute for F16 compatibility with cache operations
+    auto [k_total, v_total] = paged_kv_cache->update(layer_idx_,
+                                                     k_reshaped,
+                                                     v_reshaped,
+                                                     slot_mapping.value());
+
+    // 6. Compute attention
+    infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
+
+    if (is_prefill) {
+        infinicore::op::paged_attention_prefill_(
+            attn_output,
+            q_reshaped,
+            k_total,
+            v_total,
+            block_tables.value(),
+            cache_lengths.value(),
+            input_lengths.value(),
+            input_offsets.value(),
+            std::nullopt,
+            scaling_);
+
+    } else {
+        infinicore::op::paged_attention_(
+            attn_output,
+            q_reshaped,
+            k_total,
+            v_total,
+            block_tables.value(),
+            cache_lengths.value(),
+            std::nullopt,
+            scaling_);
+    }
+
+    // 7. Project output
+    attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
+    return o_proj_->forward(attn_output);
+}
+
+infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
+                                           const infinicore::Tensor &position_ids,
+                                           std::shared_ptr<cache::Cache> kv_cache,
+                                           std::optional<infinicore::Tensor> cache_lengths,
+                                           std::optional<infinicore::Tensor> input_lengths,
+                                           std::optional<infinicore::Tensor> input_offsets,
+                                           std::optional<infinicore::Tensor> block_tables,
+                                           std::optional<infinicore::Tensor> slot_mapping) const {
+    if (!rotary_emb_) {
+        throw std::runtime_error("LlamaAttention: rotary_emb not configured");
+    }
+
+    infinicore::Tensor output;
+    if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
+        output = forward_paged_(hidden_states, position_ids, paged_kv_cache, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
+    } else {
+        output = forward_(hidden_states, position_ids, kv_cache, cache_lengths);
+    }
+    return output;
+}
+
 void LlamaAttention::set_rotary_emb(const std::shared_ptr &rotary_emb) {
     rotary_emb_ = rotary_emb;
 }
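forward_paged_ distinguishes prefill from decode purely by comparing the flattened sequence length with the number of requests, as in the is_prefill check above. The snippet below is only a standalone illustration of that rule, not code from the patch.

    def is_prefill(seq_len, num_requests):
        # A decode step feeds exactly one new token per request, so the flattened
        # sequence length equals the number of requests; anything longer is prefill.
        return seq_len != num_requests

    print(is_prefill(seq_len=10, num_requests=2))  # True  -> paged_attention_prefill_
    print(is_prefill(seq_len=2, num_requests=2))   # False -> paged_attention_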
diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp
index 7c938b12..c88084a4 100644
--- a/csrc/models/llama/llama_attention.hpp
+++ b/csrc/models/llama/llama_attention.hpp
@@ -55,7 +55,7 @@ class LlamaAttention : public infinicore::nn::Module {
                               std::optional<infinicore::Tensor> input_lengths,
                               std::optional<infinicore::Tensor> input_offsets,
                               std::optional<infinicore::Tensor> block_tables,
-                              std::optional<infinicore::Tensor> slot_mappin) const;
+                              std::optional<infinicore::Tensor> slot_mapping) const;

     /**
      * @brief Get the layer index
@@ -73,6 +73,21 @@
     size_t head_dim() const { return head_dim_; }
     size_t hidden_size() const { return hidden_size_; }

+private:
+    infinicore::Tensor forward_(const infinicore::Tensor &hidden_states,
+                                const infinicore::Tensor &position_ids,
+                                std::shared_ptr<cache::Cache> kv_cache,
+                                std::optional<infinicore::Tensor> cache_lengths) const;
+
+    infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states,
+                                      const infinicore::Tensor &position_ids,
+                                      std::shared_ptr<cache::PagedKVCache> kv_cache,
+                                      std::optional<infinicore::Tensor> cache_lengths,
+                                      std::optional<infinicore::Tensor> input_lengths,
+                                      std::optional<infinicore::Tensor> input_offsets,
+                                      std::optional<infinicore::Tensor> block_tables,
+                                      std::optional<infinicore::Tensor> slot_mapping) const;
+
 protected:
     // Projection layers
     INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
@@ -93,6 +108,8 @@
     bool use_bias_;                  // Bias for Q/K/V projections
     bool use_output_bias_;           // Bias for output projection (o_proj)
     size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
+
+    float scaling_;
 };

 } // namespace infinilm::models::llama
diff --git a/csrc/models/llama/llama_for_causal_lm.cpp b/csrc/models/llama/llama_for_causal_lm.cpp
index f7ba70d7..155a6918 100644
--- a/csrc/models/llama/llama_for_causal_lm.cpp
+++ b/csrc/models/llama/llama_for_causal_lm.cpp
@@ -35,8 +35,7 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
     auto slot_mapping = input.slot_mapping;

     // 1. Forward through base model to get hidden states
-    auto position_ids_device = position_ids->to(device_);
-    auto hidden_states = model_->forward(input_ids, position_ids_device, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
+    auto hidden_states = model_->forward(input_ids, position_ids, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);

     // 2. Apply language modeling head to get logits
     auto logits = lm_head_->forward(hidden_states);
diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp
index 4991226d..4f0f3fde 100644
--- a/csrc/models/llama/llama_model.cpp
+++ b/csrc/models/llama/llama_model.cpp
@@ -59,15 +59,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
         hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
     }

-    // 3. Apply final layer normalization to last token only (aligns with transformers)
-    // Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
-    auto shape = hidden_states->shape();
-    size_t seq_len = shape[1];
-    auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}});
-
-    auto normalized_last_token = norm_->forward(last_token);
-
-    return normalized_last_token;
+    return norm_->forward(hidden_states);
 }

 void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp
index 9d7b7848..65c956ad 100644
--- a/csrc/pybind11/engine/engine.hpp
+++ b/csrc/pybind11/engine/engine.hpp
@@ -90,6 +90,8 @@ inline void bind_infer_engine(py::module &m) {
                     std::move(input_ids),
                     std::move(position_ids),
                     std::move(cache_lengths),
+                    std::move(input_lengths),
+                    std::move(input_offsets),
                     std::move(block_tables),
                     std::move(slot_mapping)}};

diff --git a/examples/jiuge.py b/examples/jiuge.py
index 8b7d6bcd..f4e3a7bf 100644
--- a/examples/jiuge.py
+++ b/examples/jiuge.py
@@ -9,6 +9,7 @@
 import time
 import os
 import numpy as np
+from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig

 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))

@@ -82,6 +83,18 @@ def get_args():
         default=1,
         help="total rank for tensor parallel",
     )
+    parser.add_argument(
+        "--enable-paged-attn",
+        action="store_true",
+        help="use paged cache",
+    )
+
+    parser.add_argument(
+        "--max-kvcache-size",
+        type=int,
+        default=8 * 1024 * 1024 * 1024,
+        help="max size (in bytes) allocated to paged kv cache",
+    )

     return parser.parse_args()

@@ -92,6 +105,7 @@ def test(
     max_new_tokens=100,
     infini_device=infinicore.device("cpu", 0),
     tp=1,
+    enable_paged_attn=False,
 ):
     model_path = os.path.expanduser(model_path)
     # ---------------------------------------------------------------------------- #
@@ -150,11 +164,21 @@
         "input_ids"
     ]  # List: [[1, 1128, 526, 366, 29892]]

-    # Create the KV cache based on the input length and the maximum output length
-    model.reset_cache(
-        1 if prompts is str else len(prompts),
-        max_new_tokens + len(input_ids_list[0]),
-    )
+    # ---------------------------------------------------------------------------- #
+    # Create the KV cache
+    # ---------------------------------------------------------------------------- #
+    if enable_paged_attn:
+        cache_config = PagedKVCacheConfig(
+            max_kv_memory_bytes=args.max_kvcache_size, block_size=16
+        )
+    else:
+        batch_size = 1 if prompts is str else len(prompts)
+        initial_capacity = max_new_tokens + len(input_ids_list[0])
+        cache_config = StaticKVCacheConfig(
+            max_batch_size=batch_size, max_cache_len=initial_capacity
+        )
+
+    model.reset_cache(cache_config)

     # ---------------------------------------------------------------------------- #
     # Autoregressive generation
@@ -211,7 +235,7 @@
     max_new_tokens = args.max_new_tokens
     backend = args.backend
     tp = args.tp
-
+    enable_paged_attn = args.enable_paged_attn

     if backend != "cpp":
         raise ValueError(f"Unsupported backend: {backend}.")
@@ -223,4 +247,5 @@ def test(
         max_new_tokens,
         infini_device=infini_device,
         tp=tp,
+        enable_paged_attn=enable_paged_attn,
     )
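With the example script above, the cache backend is now selected by constructing one of the two config objects and passing it to reset_cache(). A minimal usage sketch, assuming the keyword arguments behave exactly as used in examples/jiuge.py; `model` stands for an InferEngine built elsewhere, so the final call is left as a comment.

    from infinilm.cache import PagedKVCacheConfig, StaticKVCacheConfig

    enable_paged_attn = True
    if enable_paged_attn:
        # Budget-based: at most this many bytes of KV memory, split into 16-token blocks.
        cache_config = PagedKVCacheConfig(max_kv_memory_bytes=8 * 1024**3, block_size=16)
    else:
        # Shape-based: preallocated for a fixed batch size and maximum sequence length.
        cache_config = StaticKVCacheConfig(max_batch_size=1, max_cache_len=1024)

    # model.reset_cache(cache_config)  # `model` is the InferEngine built in examples/jiuge.py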
diff --git a/python/infinilm/auto_config.py b/python/infinilm/auto_config.py
index 5408fe0b..83bac52b 100644
--- a/python/infinilm/auto_config.py
+++ b/python/infinilm/auto_config.py
@@ -21,5 +21,7 @@ def from_pretrained(model_path):

     if config_dict["model_type"] == "llama":
         return LlamaConfig(**config_dict)
+    elif config_dict["model_type"] == "qwen2":
+        return LlamaConfig(**config_dict)

     raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
diff --git a/python/infinilm/cache/__init__.py b/python/infinilm/cache/__init__.py
index b40add1c..a19e59f8 100644
--- a/python/infinilm/cache/__init__.py
+++ b/python/infinilm/cache/__init__.py
@@ -1,3 +1,3 @@
-from .cache import CacheConfig, StaticKVCacheConfig
+from .cache import CacheConfig, StaticKVCacheConfig, PagedKVCacheConfig

-__all__ = ["CacheConfig", "StaticKVCacheConfig"]
+__all__ = ["CacheConfig", "StaticKVCacheConfig", "PagedKVCacheConfig"]
diff --git a/python/infinilm/generation/utils.py b/python/infinilm/generation/utils.py
index 00143231..36f54cc6 100644
--- a/python/infinilm/generation/utils.py
+++ b/python/infinilm/generation/utils.py
@@ -13,6 +13,8 @@ def infini_to_ctype_dtype(infini_dtype):
         return ctypes.c_int32
     elif infini_dtype == infinicore.float32:
         return ctypes.c_float
+    elif infini_dtype == infinicore.int64:
+        return ctypes.c_int64
     else:
         raise ValueError(f"Unsupported py_dtype: {infini_dtype}")

diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py
index 4be1008f..7de70789 100644
--- a/python/infinilm/infer_engine.py
+++ b/python/infinilm/infer_engine.py
@@ -4,7 +4,7 @@
 import infinicore

 from infinilm.auto_config import AutoConfig
-from infinilm.cache import StaticKVCacheConfig
+from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
 from infinilm.distributed import DistConfig
 from infinilm.lib import _infinilm

@@ -18,6 +18,7 @@ class GenerationConfig:
     top_p: float = 1.0

     eos_token_id: list[int] | None = None
+    stop_on_eos: bool = True


 class InferEngine(_infinilm.InferEngine):
@@ -42,6 +43,8 @@ def __init__(

         self.use_cache = False

+        self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
+
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)

@@ -93,15 +96,11 @@ def generate(self, input_ids, generation_config, *, _measure_and_log_time=False)
         else:
             eos_token_id = generation_config.eos_token_id

-        # TODO: Remove the `to_numpy` calls and simplify the corresponding code.
-        batch_size, seq_len = input_ids.shape[:2]
-
-        position_ids = infinicore.from_list(
-            [list(range(0, seq_len)) for _ in range(batch_size)], dtype=infinicore.int64
-        )
-        cache_lengths = infinicore.from_list([0], dtype=infinicore.int64)
-
+        past_seq_len = 0
         output_ids = []
+        initial_batch_size, initial_seqlen = input_ids.shape[:2]
+        seq_len = initial_seqlen
+        batch_size = initial_batch_size

         if batch_size != 1 and generation_config.max_new_tokens is None:
             raise ValueError(
@@ -111,14 +110,76 @@
         if _measure_and_log_time:
             time_measurements = []

-        for _ in range(0, generation_config.max_new_tokens):
+        for iter in range(0, generation_config.max_new_tokens):
             if _measure_and_log_time:
                 start_time = time.perf_counter()

+            batch_size, seq_len = input_ids.shape[:2]
+
+            if self.enable_paged_attn:
+                input_ids = input_ids.view([1, batch_size * seq_len])
+                position_ids = infinicore.from_list(
+                    list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
+                    dtype=infinicore.int64,
+                )
+                cache_lengths = infinicore.from_list(
+                    [past_seq_len] * batch_size, dtype=infinicore.int64
+                )
+                input_lengths = infinicore.from_list(
+                    [seq_len] * batch_size, dtype=infinicore.int64
+                )
+
+                input_offsets = infinicore.from_list(
+                    [seq_len * i for i in range(batch_size)], dtype=infinicore.int64
+                )
+                block_tables = infinicore.from_list(
+                    [
+                        [
+                            i * batch_size + b
+                            for i in range((past_seq_len + seq_len + 15) // 16)
+                        ]
+                        for b in range(batch_size)
+                    ],
+                    dtype=infinicore.int64,
+                )
+                slot_mapping = infinicore.from_list(
+                    [
+                        ((past_seq_len + i + 15) // 16) * batch_size
+                        + b
+                        + (past_seq_len + i + 15) % 16
+                        for i in range(seq_len)
+                        for b in range(batch_size)
+                    ],
+                    dtype=infinicore.int64,
+                )
+            else:
+                position_ids = infinicore.from_list(
+                    [
+                        list(range(past_seq_len, past_seq_len + seq_len))
+                        for _ in range(batch_size)
+                    ],
+                    dtype=infinicore.int64,
+                )
+                cache_lengths = infinicore.from_list(
+                    [past_seq_len], dtype=infinicore.int64
+                )
+                input_lengths = infinicore.from_list(
+                    [seq_len] * batch_size, dtype=infinicore.int64
+                )
+                input_offsets = infinicore.from_list(
+                    [seq_len * i for i in range(batch_size)], dtype=infinicore.int64
+                )
+                block_tables = None
+                slot_mapping = None
+
             output_id = self(
-                input_ids,
+                input_ids=input_ids,
                 position_ids=position_ids,
                 cache_lengths=cache_lengths,
+                input_lengths=input_lengths,
+                input_offsets=input_offsets,
+                block_tables=block_tables,
+                slot_mapping=slot_mapping,
                 temperature=generation_config.temperature,
                 top_k=generation_config.top_k,
                 top_p=generation_config.top_p,
@@ -127,24 +188,16 @@
             output_ids.append(output_id)

             if (
-                generation_config.max_new_tokens is not None
+                generation_config.stop_on_eos
+                and generation_config.max_new_tokens is not None
                 and output_id.to_numpy()[0] in eos_token_id
             ):
                 break

-            seq_len = position_ids.shape[-1]
-
             input_ids = infinicore.from_list(
                 [[output_id] for output_id in output_id.to_numpy().tolist()]
             )
-            position_ids = infinicore.from_list(
-                [1 for _ in range(batch_size)],
-                dtype=position_ids.dtype,
-                device=position_ids.device,
-            ).view((batch_size, 1)) + position_ids.narrow(1, seq_len - 1, 1)
-            cache_lengths += infinicore.from_list(
-                [seq_len], dtype=cache_lengths.dtype, device=cache_lengths.device
-            )
+            past_seq_len = past_seq_len + seq_len

             if _measure_and_log_time:
                 end_time = time.perf_counter()
@@ -156,23 +209,21 @@ def generate(self, input_ids, generation_config, *, _measure_and_log_time=False)
                 f"\n\n\n Generation completed in {round(sum(time_measurements) * 1000, 2)} ms"
             )
             print(
-                f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_measurements)}\n"
+                f" Batchsize={initial_batch_size} Per_Batch_Input_Len={initial_seqlen} Per_Batch_New_Tokens={len(time_measurements)}\n"
             )
             print(
-                f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((batch_size * seq_len) / time_measurements[0], 2)}tok/s\n",
+                f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((initial_batch_size * initial_seqlen) / time_measurements[0], 2)}tok/s\n",
             )
             if len(time_measurements) > 1:
                 print(
-                    f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n",
+                    f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((initial_batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n",
                 )

         return output_ids

-    def reset_cache(self, batch_size: int, initial_capacity: int = 1024):
+    def reset_cache(self, cache_config):
         infinicore.sync_device()
-
-        cache_config = StaticKVCacheConfig(batch_size, initial_capacity)
-
+        self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
         super().reset_cache(cache_config)

     def state_dict_keyname(self):
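The paged branch of generate() hands out cache blocks in an interleaved pattern: request b owns blocks b, b + batch_size, b + 2 * batch_size, and so on, which is what the block_tables comprehension above encodes. The sketch below reproduces that layout with an illustrative helper (build_block_tables is not part of the library) and generalizes the hard-coded block size of 16.

    def build_block_tables(batch_size, past_seq_len, seq_len, block_size=16):
        # Request b gets every batch_size-th block starting at b, so the
        # requests' blocks interleave: b, b + batch_size, b + 2 * batch_size, ...
        num_blocks = (past_seq_len + seq_len + block_size - 1) // block_size
        return [[i * batch_size + b for i in range(num_blocks)]
                for b in range(batch_size)]

    # Two requests, 20 tokens each so far -> two 16-token blocks per request
    print(build_block_tables(batch_size=2, past_seq_len=0, seq_len=20))  # [[0, 2], [1, 3]]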