diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp index 214d20e3..9d66a2dd 100644 --- a/csrc/cache/kv_cache.cpp +++ b/csrc/cache/kv_cache.cpp @@ -1,7 +1,7 @@ #include "kv_cache.hpp" #include "../utils.hpp" - +#include "infinicore/ops.hpp" #include namespace infinilm::cache { @@ -80,12 +80,12 @@ std::tuple StaticKVCache::update(size_t layer_idx, const infinicore::Tensor &k, const infinicore::Tensor &v, - const infinicore::Tensor &cache_lengths) { + const infinicore::Tensor &past_sequence_lengths) { ASSERT(layer_idx < rank_num_layers_); auto batch_size = k->size(0); auto update_len = k->size(2); - size_t cache_pos = reinterpret_cast(cache_lengths->to(infinicore::Device::cpu())->data())[0]; + size_t cache_pos = reinterpret_cast(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0]; auto result_len = cache_pos + update_len; ASSERT(result_len <= cache_len_); @@ -111,9 +111,9 @@ StaticKVCache::update(size_t layer_idx, // PagedKVCacheConfig // ========================== PagedKVCacheConfig::PagedKVCacheConfig( - size_t max_kv_memory_bytes, + size_t num_blocks, size_t block_size) - : max_kv_memory_bytes_(max_kv_memory_bytes), + : num_blocks_(num_blocks), block_size_(block_size) { } @@ -123,8 +123,8 @@ PagedKVCacheConfig::unique_copy() const { } size_t -PagedKVCacheConfig::max_kv_memory_bytes() const { - return max_kv_memory_bytes_; +PagedKVCacheConfig::num_blocks() const { + return num_blocks_; } size_t @@ -151,15 +151,8 @@ PagedKVCache::PagedKVCache( num_rank_v_heads_(num_v_heads / rank_info.tp_size), rank_num_layers_(num_layers), dtype_(dtype), + num_blocks_per_layer_(config.num_blocks()), block_size_(config.block_size()) { - num_blocks_per_layer_ = config.max_kv_memory_bytes() - / (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_) - / block_size_ - / infinicore::dsize(dtype_); - if (num_blocks_per_layer_ == 0) { - throw std::runtime_error("Not enough memory for KV cache"); - } - // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim] k_caches_ = infinicore::Tensor::empty( {rank_num_layers_, @@ -187,11 +180,79 @@ std::tuple PagedKVCache::update( const infinicore::Tensor &v, const infinicore::Tensor &slot_mapping) { + auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx); + + infinicore::op::paged_caching_( + k_cache_layer, + v_cache_layer, + k, + v, + slot_mapping); + return {k_cache_layer, v_cache_layer}; +} + +std::tuple +PagedKVCache::get_paged_kv(size_t layer_idx) { auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0); auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0); + return {k_cache_layer, v_cache_layer}; +} - /// @todo: implement paged cache update here +std::tuple +PagedKVCache::get_contiguous_kv( + size_t layer_idx, + const infinicore::Tensor block_tables, + const infinicore::Tensor cache_lens, + const infinicore::Tensor input_offsets, + size_t request_id) { + ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64); + ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64); + ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64); - return {k_cache_layer, v_cache_layer}; + auto nreq = block_tables->size(0); + auto block_tables_cpu = block_tables->to(infinicore::Device::cpu()); + auto cache_lens_cpu = cache_lens->to(infinicore::Device::cpu()); + auto input_offsets_cpu = input_offsets->to(infinicore::Device::cpu()); + infinicore::context::syncDevice(); + + // [num_blocks, num_rank_v_heads, block_size, v_dim] + auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx); + + auto req = 
request_id; + auto cache_lens_ptr = reinterpret_cast(cache_lens_cpu->data()); + auto input_offsets_ptr = reinterpret_cast(input_offsets_cpu->data()); + int64_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]); + + auto full_k = infinicore::Tensor::empty( + {num_rank_k_heads_, (size_t)total_len, k_dim_}, + k_cache_layer->dtype(), k_cache_layer->device()); + + auto full_v = infinicore::Tensor::empty( + {num_rank_v_heads_, (size_t)total_len, v_dim_}, + v_cache_layer->dtype(), v_cache_layer->device()); + + size_t nblocks = total_len / block_size_; + size_t r = total_len % block_size_; + + for (size_t b = 0; b < nblocks; b++) { + size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data())); + + full_k->narrow({{1, b * block_size_, block_size_}}) + ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)); + full_v->narrow({{1, b * block_size_, block_size_}}) + ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0)); + } + + if (r > 0) { + size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data())); + + full_k->narrow({{1, nblocks * block_size_, r}}) + ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}})); + full_v->narrow({{1, nblocks * block_size_, r}}) + ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}})); + } + + return {full_k, full_v}; } + } // namespace infinilm::cache diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp index bcce66e5..54b84cc1 100644 --- a/csrc/cache/kv_cache.hpp +++ b/csrc/cache/kv_cache.hpp @@ -61,7 +61,7 @@ class StaticKVCache final : public Cache { update(size_t layer_idx, const infinicore::Tensor &k, const infinicore::Tensor &v, - const infinicore::Tensor &cache_lengths); + const infinicore::Tensor &past_sequence_lengths); ~StaticKVCache() override = default; @@ -85,15 +85,15 @@ class StaticKVCache final : public Cache { class PagedKVCacheConfig final : public CacheConfig { public: PagedKVCacheConfig( - size_t max_kv_memory_bytes, + size_t num_blocks, size_t block_size = 16); std::unique_ptr unique_copy() const override; - size_t max_kv_memory_bytes() const; + size_t num_blocks() const; size_t block_size() const; private: - size_t max_kv_memory_bytes_; + size_t num_blocks_; size_t block_size_; }; @@ -113,7 +113,7 @@ class PagedKVCache final : public Cache { /** * @brief Update Paged KV cache at a given layer given slot info for each token. * - * @param layer_idx Which transformer layer + * @param layer_idx Which paged attention layer * @param k [num_rank_k_heads, seq_len, k_dim] * @param v [num_rank_v_heads, seq_len, v_dim] * @param slot_mapping [seq_len] @@ -128,7 +128,41 @@ class PagedKVCache final : public Cache { const infinicore::Tensor &v, const infinicore::Tensor &slot_mapping); - ~PagedKVCache() override = default; + /** + * @brief Get Paged KV cache at a given layer. + * + * @param layer_idx Which paged attention layer + * + * @return (full_k, full_v) + * full_k: [num_blocks, num_rank_k_heads, block_size, k_dim] + * full_v: [num_blocks, num_rank_v_heads, block_size, v_dim] + */ + std::tuple + get_paged_kv(size_t layer_idx); + + /** + * @brief Get contiguous KV cache at a given layer, given the request info + * among a continuous request batch. 
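Conceptually, get_contiguous_kv walks a request's block table and copies whole blocks, plus a possibly partial tail block, into one contiguous buffer. A minimal NumPy sketch of that gathering, assuming the same [num_blocks, num_heads, block_size, head_dim] per-layer cache layout; the function and argument names here are illustrative, not part of the API:

import numpy as np

def gather_contiguous_kv(k_cache, block_table, cache_len, new_len, block_size):
    # k_cache:     [num_blocks, num_heads, block_size, head_dim] paged cache for one layer
    # block_table: ordered block ids owned by this request
    # cache_len:   tokens already cached; new_len: tokens added this step
    #              (input_offsets[req + 1] - input_offsets[req])
    total_len = cache_len + new_len
    num_heads, head_dim = k_cache.shape[1], k_cache.shape[3]
    full_k = np.empty((num_heads, total_len, head_dim), dtype=k_cache.dtype)

    nblocks, rem = divmod(total_len, block_size)
    for b in range(nblocks):                      # copy every full block
        full_k[:, b * block_size:(b + 1) * block_size, :] = k_cache[block_table[b]]
    if rem > 0:                                   # copy the partial tail block
        full_k[:, nblocks * block_size:, :] = k_cache[block_table[nblocks]][:, :rem, :]
    return full_k

The same walk is repeated for V; on device the copies go through narrow()/copy_from() as in the diff above.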
+ * + * @param layer_idx Which paged attention layer + * @param block_tables [num_requests, max_blocks_per_request] + * @param cache_lens [num_requests] + * @param input_offsets [num_requests + 1] + * @param request_id Which request among a continuous batch of requests + * + * @return (full_k, full_v) + * full_k: [num_rank_k_heads, total_len, k_dim] + * full_v: [num_rank_v_heads, total_len, v_dim] + */ + std::tuple + get_contiguous_kv(size_t layer_idx, + const infinicore::Tensor block_tables, + const infinicore::Tensor cache_lens, + const infinicore::Tensor input_offsets, + size_t request_id = 0); + + ~PagedKVCache() override + = default; private: infinicore::Size k_dim_; diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 663d03ab..482117c0 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -56,8 +56,23 @@ std::vector> InferEng //------------------------------------------------------ // forward //------------------------------------------------------ -infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const { - return {input_ids, position_ids, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping}; +infinilm::InfinilmModel::Input +InferEngine::Input::to_model_input(infinicore::Device device) const { + + auto to_device = [&](const std::optional &t) + -> std::optional { + return t.has_value() ? t.value()->to(device) : t; + }; + + return { + input_ids, // @todo: on device in the future + to_device(position_ids), + past_sequence_lengths, // @todo: on device in the future + to_device(total_sequence_lengths), + to_device(input_offsets), + to_device(block_tables), + to_device(slot_mapping), + }; } InferEngine::Output InferEngine::forward(const InferEngine::Input &input) { diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 3c335b23..003fb265 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -188,7 +188,7 @@ void RankWorker::thread_loop() { Command local_cmd = Command::INIT; std::string local_param_name; infinicore::Tensor local_param; - InfinilmModel::Input local_args; + Input local_args; std::unique_ptr local_cache_config; // Wait for a job or exit @@ -206,7 +206,7 @@ void RankWorker::thread_loop() { local_param_name = pending_param_name_; local_param = pending_param_; } else if (local_cmd == Command::RUN) { - local_args = pending_args_.to_model_input(); + local_args = pending_args_; } else if (local_cmd == Command::RESET_CACHE) { if (pending_cache_config_ != nullptr) { local_cache_config = pending_cache_config_->unique_copy(); @@ -244,23 +244,28 @@ void RankWorker::thread_loop() { { std::lock_guard lk(mutex_); - auto logits{model_->forward(local_args).logits}; - + auto model_args = local_args.to_model_input(rank_info_.device); + // Forward calculation + auto logits{model_->forward(model_args).logits}; + // Random sampling (rank 0 only) if (rank_info_.tp_rank == 0) { - // Perform random sampling. 
- auto temperature{pending_args_.temperature}; - auto top_p{pending_args_.top_p}; - auto top_k{pending_args_.top_k}; - auto random_val{pending_args_.random_val}; + auto temperature{local_args.temperature}; + auto top_p{local_args.top_p}; + auto top_k{local_args.top_k}; + auto random_val{local_args.random_val}; const auto &logits_shape{logits->shape()}; - const auto &batch_size{logits_shape[0]}; const auto &vocab_size{logits_shape[2]}; + const auto &total_len{logits_shape[1]}; + const auto &batch_size{logits_shape[0]}; + + auto n_req = local_args.input_offsets.value()->size(0) - 1; + int64_t *input_offsets = (int64_t *)local_args.input_offsets.value()->data(); - auto output_ids{infinicore::Tensor::empty({batch_size}, infinicore::DataType::I32, rank_info_.device)}; + auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)}; - for (auto i{decltype(batch_size)(0)}; i < batch_size; ++i) { - auto score{logits->narrow({{0, i, 1}})->view({vocab_size})}; + for (auto i{decltype(n_req)(0)}; i < n_req; ++i) { + auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i + 1] - 1), 1}})->view({vocab_size})}; auto out{output_ids->narrow({{0, i, 1}})->view({})}; infinicore::op::random_sample_( out, score, random_val, top_p, top_k, temperature); diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index 89640028..98bb4b87 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -29,9 +29,9 @@ class RankWorker { /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`. std::optional position_ids; /// Past Lengths of cached sequence for each request, of shape `[num_requests]`. - std::optional cache_lengths; - /// Input Lengths of each request in a continous-batched sequence, of shape `[num_requests]`. - std::optional input_lengths; + std::optional past_sequence_lengths; + /// ToTal Lengths for each request sequence, of shape `[num_requests]`. + std::optional total_sequence_lengths; /// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`. std::optional input_offsets; /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. @@ -47,7 +47,7 @@ class RankWorker { float random_val{0.1}; - infinilm::InfinilmModel::Input to_model_input() const; + infinilm::InfinilmModel::Input to_model_input(infinicore::Device device) const; }; struct Output { diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp index 3eb54249..4cad3b6c 100644 --- a/csrc/models/infinilm_model.hpp +++ b/csrc/models/infinilm_model.hpp @@ -23,10 +23,10 @@ class InfinilmModel : public infinicore::nn::Module { /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`. std::optional position_ids; /// Past Lengths of cached sequence for each request, of shape `[num_requests]`. - std::optional cache_lengths; - /// Input Lengths of each request in a continous-batched sequence, of shape `[num_requests]`. - std::optional input_lengths; - /// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`. + std::optional past_sequence_lengths; + /// ToTal Lengths for each request sequence, of shape `[num_requests]`. + std::optional total_sequence_lengths; + /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`. std::optional input_offsets; /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. 
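With continuous batching the logits come back flattened, so the sampling loop above indexes each request's last token through input_offsets rather than assuming one token per batch row. A small NumPy sketch of that selection, assuming int64 offsets of shape [num_requests + 1]; names are illustrative:

import numpy as np

def last_token_logits(logits, input_offsets):
    # logits:        [1, total_len, vocab_size] for the flattened continuous batch
    # input_offsets: prefix offsets; request i occupies [input_offsets[i], input_offsets[i + 1])
    vocab_size = logits.shape[-1]
    flat = logits.reshape(-1, vocab_size)               # [total_len, vocab_size]
    last_idx = np.asarray(input_offsets[1:]) - 1        # last position of each request
    return flat[last_idx]                               # [num_requests, vocab_size]

# Two requests of lengths 3 and 2 -> offsets [0, 3, 5], so rows 2 and 4 are sampled.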
std::optional block_tables; diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index 78be6a87..c8f2d71d 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -1,5 +1,6 @@ #include "llama_attention.hpp" +#include "../../utils.hpp" #include "infinicore/nn/linear.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/ops.hpp" @@ -43,6 +44,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, } else { throw std::runtime_error("num_attention_heads / tp_size error."); } + scaling_ = 1.0f / std::sqrt(static_cast(head_dim_)); // Initialize projection layers INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_, @@ -52,17 +54,11 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, dtype, device, tp_rank, tp_size, rank_info.comm); } -infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states, - const infinicore::Tensor &position_ids, - std::shared_ptr kv_cache, - std::optional cache_lengths, - std::optional input_lengths, - std::optional input_offsets, - std::optional block_tables, - std::optional slot_mapping) const { - if (!rotary_emb_) { - throw std::runtime_error("LlamaAttention: rotary_emb not configured"); - } +infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &position_ids, + std::shared_ptr kv_cache, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths) const { // Input shape: [batch, seq_len, hidden_size] auto hidden_states_mutable = hidden_states; auto shape = hidden_states->shape(); @@ -73,7 +69,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); // 2. Reshape for multi-head attention - // Reshape Q, K, V to include batch dimension // Python: query_states = self.q_proj(hidden_states).view(querys_shape) // The view operation requires the tensor to be contiguous in the required dimensions @@ -111,16 +106,9 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat k_total = k_permuted; v_total = v_permuted; } else if (auto static_kv_cache = std::dynamic_pointer_cast(kv_cache)) { - auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value()); - k_total = k_total_tmp; - v_total = v_total_tmp; - } else if (auto paged_kv_cache = std::dynamic_pointer_cast(kv_cache)) { - auto [k_total_tmp, v_total_tmp] = paged_kv_cache->update(layer_idx_, k_permuted, v_permuted, slot_mapping.value()); + auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, past_sequence_lengths.value()); k_total = k_total_tmp; v_total = v_total_tmp; - - /// @todo Implement paged attention here. 
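For reference, the dense (non-paged) path kept below computes grouped-query attention by folding each KV group's query heads together, scaling Q·Kᵀ by the precomputed scaling_, applying a causal softmax, and multiplying by V. A rough NumPy sketch of that math, assuming query heads are ordered so that consecutive heads share a KV head; illustrative only:

import numpy as np

def gqa_attention(q, k, v, scaling):
    # q: [n_q_heads, seq_len, head_dim]; k, v: [n_kv_heads, total_len, head_dim]
    n_q, seq_len, head_dim = q.shape
    n_kv, total_len, _ = k.shape
    group = n_q // n_kv
    qg = q.reshape(n_kv, group * seq_len, head_dim)

    scores = (qg @ k.transpose(0, 2, 1) * scaling).reshape(n_q, seq_len, total_len)

    # causal mask: query i may attend to key positions [0, total_len - seq_len + i]
    allowed = (total_len - seq_len) + np.arange(seq_len)[:, None]
    scores = np.where(np.arange(total_len)[None, :] > allowed, -np.inf, scores)

    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    out = probs.reshape(n_kv, group * seq_len, total_len) @ v
    return out.reshape(n_q, seq_len, head_dim)           # [n_q_heads, seq_len, head_dim]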
- throw std::runtime_error("LlamaAttention: Paged attention not implemented"); } else { throw std::runtime_error("LlamaAttention: Unsupported kvcache type"); } @@ -134,8 +122,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len] - float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); - auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [bs * n_kv_head, ng * seq_len, total_seq_len] + auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len] auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len}); infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax); @@ -152,6 +139,116 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat return output; } +infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &position_ids, + std::shared_ptr paged_kv_cache, + std::optional total_sequence_lengths, + std::optional input_offsets, + std::optional block_tables, + std::optional slot_mapping) const { + ASSERT(block_tables.has_value()); + ASSERT(slot_mapping.has_value()); + + // Input shape: [batch, seq_len, hidden_size] + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + // Only support batchsize==1, all requests should be flattened along seqlen dimension + ASSERT_EQ(batch_size, 1); + // Decode only if total_len == num_requests + bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); + + // 1. Project Q, K, V + auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + + // 2. Reshape for multi-head attention + + // Reshape Q, K, V to include batch dimension + // Python: query_states = self.q_proj(hidden_states).view(querys_shape) + // The view operation requires the tensor to be contiguous in the required dimensions + auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); + auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_}); + auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_}); + + // 3. Prepare position_ids for RoPE - align with Python pattern + auto pos_shape = position_ids->shape(); + infinicore::Tensor pos_ids_for_rope = position_ids; + if (pos_shape.size() == 2) { + auto pos_narrowed = position_ids->narrow({{0, 0, 1}}); + pos_ids_for_rope = pos_narrowed->view({pos_shape[1]}); + } else if (pos_shape.size() == 1) { + pos_ids_for_rope = position_ids; + } else { + throw std::runtime_error("Unexpected position_ids shape"); + } + + // 4. Apply RoPE to Q and K + rotary_emb_->forward(q_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_q_head, head_dim] + rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim] + + // 5. Prepare KV caches + // Ensure contiguous after permute for F16 compatibility with cache operations + auto [k_total, v_total] = paged_kv_cache->update(layer_idx_, + k_reshaped, + v_reshaped, + slot_mapping.value()); + + // 6. 
Compute attention + infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device()); + + if (is_prefill) { + infinicore::op::paged_attention_prefill_( + attn_output, + q_reshaped, + k_total, + v_total, + block_tables.value(), + total_sequence_lengths.value(), + input_offsets.value(), + std::nullopt, + scaling_); + + } else { + infinicore::op::paged_attention_( + attn_output, + q_reshaped, + k_total, + v_total, + block_tables.value(), + total_sequence_lengths.value(), + std::nullopt, + scaling_); + } + + // 7. Project output + attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_}); + return o_proj_->forward(attn_output); +} + +infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &position_ids, + std::shared_ptr kv_cache, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths, + std::optional input_offsets, + std::optional block_tables, + std::optional slot_mapping) const { + if (!rotary_emb_) { + throw std::runtime_error("LlamaAttention: rotary_emb not configured"); + } + + infinicore::Tensor output; + if (auto paged_kv_cache = std::dynamic_pointer_cast(kv_cache)) { + output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, block_tables, slot_mapping); + } else { + + output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths); + } + return output; +} + void LlamaAttention::set_rotary_emb(const std::shared_ptr &rotary_emb) { rotary_emb_ = rotary_emb; } diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp index 7c938b12..d732d107 100644 --- a/csrc/models/llama/llama_attention.hpp +++ b/csrc/models/llama/llama_attention.hpp @@ -51,11 +51,11 @@ class LlamaAttention : public infinicore::nn::Module { infinicore::Tensor forward(const infinicore::Tensor &hidden_states, const infinicore::Tensor &position_ids, std::shared_ptr kv_cache, - std::optional cache_lengths, - std::optional input_lengths, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths, std::optional input_offsets, std::optional block_tables, - std::optional slot_mappin) const; + std::optional slot_mapping) const; /** * @brief Get the layer index @@ -73,6 +73,21 @@ class LlamaAttention : public infinicore::nn::Module { size_t head_dim() const { return head_dim_; } size_t hidden_size() const { return hidden_size_; } +private: + infinicore::Tensor forward_(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &position_ids, + std::shared_ptr kv_cache, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths) const; + + infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &position_ids, + std::shared_ptr kv_cache, + std::optional total_sequence_lengths, + std::optional input_offsets, + std::optional block_tables, + std::optional slot_mapping) const; + protected: // Projection layers INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj); @@ -93,6 +108,8 @@ class LlamaAttention : public infinicore::nn::Module { bool use_bias_; // Bias for Q/K/V projections bool use_output_bias_; // Bias for output projection (o_proj) size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility) + + float scaling_; }; } // namespace infinilm::models::llama diff --git 
a/csrc/models/llama/llama_decoder_layer.cpp b/csrc/models/llama/llama_decoder_layer.cpp index f0246fe6..35a1acab 100644 --- a/csrc/models/llama/llama_decoder_layer.cpp +++ b/csrc/models/llama/llama_decoder_layer.cpp @@ -26,8 +26,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states, const infinicore::Tensor &position_ids, std::shared_ptr kv_cache, - std::optional cache_lengths, - std::optional input_lengths, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths, std::optional input_offsets, std::optional block_tables, std::optional slot_mapping) const { @@ -38,7 +38,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s auto normed_states = input_layernorm_->forward(hidden_states); // 2. Self-attention with residual connection - auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping); + auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping); // Add residual: hidden_states = hidden_states + attn_output auto output = infinicore::op::add(residual, attn_output); diff --git a/csrc/models/llama/llama_decoder_layer.hpp b/csrc/models/llama/llama_decoder_layer.hpp index 9999e287..4ded50a7 100644 --- a/csrc/models/llama/llama_decoder_layer.hpp +++ b/csrc/models/llama/llama_decoder_layer.hpp @@ -49,8 +49,8 @@ class LlamaDecoderLayer : public infinicore::nn::Module { infinicore::Tensor forward(const infinicore::Tensor &hidden_states, const infinicore::Tensor &position_ids, std::shared_ptr kv_cache, - std::optional cache_lengths, - std::optional input_lengths, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths, std::optional input_offsets, std::optional block_tables, std::optional slot_mappin) const; diff --git a/csrc/models/llama/llama_for_causal_lm.cpp b/csrc/models/llama/llama_for_causal_lm.cpp index f7ba70d7..6ce1fd98 100644 --- a/csrc/models/llama/llama_for_causal_lm.cpp +++ b/csrc/models/llama/llama_for_causal_lm.cpp @@ -28,15 +28,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { auto input_ids = input.input_ids.value(); auto position_ids = input.position_ids.value(); - auto cache_lengths = input.cache_lengths; - auto input_lengths = input.input_lengths; + auto past_sequence_lengths = input.past_sequence_lengths; + auto total_sequence_length = input.total_sequence_lengths; auto input_offsets = input.input_offsets; auto block_tables = input.block_tables; auto slot_mapping = input.slot_mapping; // 1. Forward through base model to get hidden states - auto position_ids_device = position_ids->to(device_); - auto hidden_states = model_->forward(input_ids, position_ids_device, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping); + auto hidden_states = model_->forward( + input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, block_tables, slot_mapping); // 2. 
Apply language modeling head to get logits auto logits = lm_head_->forward(hidden_states); diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp index 4991226d..95193ec6 100644 --- a/csrc/models/llama/llama_model.cpp +++ b/csrc/models/llama/llama_model.cpp @@ -45,8 +45,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config, infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, const infinicore::Tensor &position_ids, - std::optional cache_lengths, - std::optional input_lengths, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths, std::optional input_offsets, std::optional block_tables, std::optional slot_mapping) const { @@ -56,18 +56,10 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, // 2. Process through all decoder layers size_t num_layers = layers_.size(); for (size_t i = 0; i < num_layers; ++i) { - hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping); + hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping); } - // 3. Apply final layer normalization to last token only (aligns with transformers) - // Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size] - auto shape = hidden_states->shape(); - size_t seq_len = shape[1]; - auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}}); - - auto normalized_last_token = norm_->forward(last_token); - - return normalized_last_token; + return norm_->forward(hidden_states); } void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { diff --git a/csrc/models/llama/llama_model.hpp b/csrc/models/llama/llama_model.hpp index 039208f3..5a008b0f 100644 --- a/csrc/models/llama/llama_model.hpp +++ b/csrc/models/llama/llama_model.hpp @@ -48,15 +48,15 @@ class LlamaModel : public infinicore::nn::Module { * @param input_ids Token IDs tensor of shape [batch, seq_len]. Batch is 1 when continuous batch is used, * and tokens from all requests are concatenated along seq_len dimension. 
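The shape convention described here (batch stays 1, requests concatenated along seq_len) is exactly what past_sequence_lengths, total_sequence_lengths, and input_offsets describe. A tiny example of that flattening for two fresh requests, using illustrative token ids:

import numpy as np

reqs = [[11, 12, 13], [21, 22, 23, 24, 25]]                 # two prompts, lengths 3 and 5

input_ids = np.array([sum(reqs, [])])                       # [1, 8]: flattened along seq_len
position_ids = np.concatenate([np.arange(len(r)) for r in reqs])          # [0 1 2 0 1 2 3 4]
input_offsets = np.cumsum([0] + [len(r) for r in reqs])                    # [0 3 8]
past_sequence_lengths = np.zeros(len(reqs), dtype=np.int64)                # nothing cached yet
total_sequence_lengths = past_sequence_lengths + np.diff(input_offsets)    # [3 5]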
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] - * @param cache_lengths Cache positions tensor of shape [n_req] - * @param input_lengths Input lengths tensor in a continuous batch of shape [n_req] - * @param input_offsets Input offsets (starting position) of each request in a continuous batch of shape [n_req] + * @param past_sequence_lengths Cache positions tensor of shape [n_req] + * @param total_sequence_lengths Total sequence lengths tensor of shape [n_req] + * @param input_offsets Input offsets (starting position) of each request in a continuous batch of shape [n_req + 1] * @return Output tensor of shape [batch, seq_len, hidden_size] */ infinicore::Tensor forward(const infinicore::Tensor &input_ids, const infinicore::Tensor &position_ids, - std::optional cache_lengths, - std::optional input_lengths, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths, std::optional input_offsets, std::optional block_tables, std::optional slot_mapping) const; diff --git a/csrc/pybind11/cache/cache.hpp b/csrc/pybind11/cache/cache.hpp index 5c155b04..d9f1985c 100644 --- a/csrc/pybind11/cache/cache.hpp +++ b/csrc/pybind11/cache/cache.hpp @@ -36,11 +36,11 @@ inline void bind_cache(py::module &m) { std::shared_ptr>(m, "PagedKVCacheConfig") .def( py::init(), - py::arg("max_kv_memory_bytes"), + py::arg("num_blocks"), py::arg("block_size") = 16) .def( - "max_kv_memory_bytes", - &infinilm::cache::PagedKVCacheConfig::max_kv_memory_bytes) + "num_blocks", + &infinilm::cache::PagedKVCacheConfig::num_blocks) .def( "block_size", &infinilm::cache::PagedKVCacheConfig::block_size) diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 9d7b7848..5ac38d70 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -80,28 +80,48 @@ inline void bind_infer_engine(py::module &m) { py::init([]( std::optional input_ids, std::optional position_ids, - std::optional cache_lengths, - std::optional input_lengths, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths, std::optional input_offsets, std::optional block_tables, std::optional slot_mapping, py::kwargs kwargs) { - auto input{InferEngine::Input{ + InferEngine::Input input{ std::move(input_ids), std::move(position_ids), - std::move(cache_lengths), + std::move(past_sequence_lengths), + std::move(total_sequence_lengths), + std::move(input_offsets), std::move(block_tables), - std::move(slot_mapping)}}; + std::move(slot_mapping), + }; - if (kwargs) { - if (kwargs.contains("temperature")) { - input.temperature = kwargs["temperature"].cast(); - } - if (kwargs.contains("top_k")) { - input.top_k = kwargs["top_k"].cast(); + // Explicit defaults + input.temperature = 1.0f; + input.top_p = 1.0f; + input.top_k = 1; + + // Allowed keyword arguments + static const std::unordered_set allowed_kwargs = { + "temperature", + "top_p", + "top_k", + }; + + for (auto &item : kwargs) { + const std::string key = py::cast(item.first); + + if (allowed_kwargs.find(key) == allowed_kwargs.end()) { + throw py::value_error( + "InferEngine.Input got an unexpected keyword argument '" + key + "'"); } - if (kwargs.contains("top_p")) { - input.top_p = kwargs["top_p"].cast(); + + if (key == "temperature") { + input.temperature = py::cast(item.second); + } else if (key == "top_p") { + input.top_p = py::cast(item.second); + } else if (key == "top_k") { + input.top_k = py::cast(item.second); } } @@ -109,18 +129,21 @@ inline void bind_infer_engine(py::module &m) { }), 
py::arg("input_ids") = std::nullopt, py::arg("position_ids") = std::nullopt, - py::arg("cache_lengths") = std::nullopt, - py::arg("input_lengths") = std::nullopt, + py::arg("past_sequence_lengths") = std::nullopt, + py::arg("total_sequence_lengths") = std::nullopt, py::arg("input_offsets") = std::nullopt, py::arg("block_tables") = std::nullopt, py::arg("slot_mapping") = std::nullopt) .def_readwrite("input_ids", &InferEngine::Input::input_ids) .def_readwrite("position_ids", &InferEngine::Input::position_ids) - .def_readwrite("cache_lengths", &InferEngine::Input::cache_lengths) - .def_readwrite("input_lengths", &InferEngine::Input::input_lengths) + .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths) + .def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths) .def_readwrite("input_offsets", &InferEngine::Input::input_offsets) .def_readwrite("block_tables", &InferEngine::Input::block_tables) - .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping); + .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping) + .def_readwrite("temperature", &InferEngine::Input::temperature) + .def_readwrite("top_k", &InferEngine::Input::top_k) + .def_readwrite("top_p", &InferEngine::Input::top_p); py::class_(infer_engine, "Output") .def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor"); diff --git a/examples/jiuge.py b/examples/jiuge.py index 8b7d6bcd..c1ad567e 100644 --- a/examples/jiuge.py +++ b/examples/jiuge.py @@ -9,6 +9,7 @@ import time import os import numpy as np +from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python")) @@ -82,6 +83,11 @@ def get_args(): default=1, help="total rank for tensor parallel", ) + parser.add_argument( + "--enable-paged-attn", + action="store_true", + help="use paged cache", + ) return parser.parse_args() @@ -92,10 +98,11 @@ def test( max_new_tokens=100, infini_device=infinicore.device("cpu", 0), tp=1, + enable_paged_attn=False, ): model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # - # 创建模型, + # Create Model # ---------------------------------------------------------------------------- # model = InferEngine( model_path, @@ -104,12 +111,12 @@ def test( ) # ---------------------------------------------------------------------------- # - # 加载权重 + # Load Weights # ---------------------------------------------------------------------------- # load_model_state_dict_by_file(model, model_path, dtype=model.config.dtype) # ---------------------------------------------------------------------------- # - # 创建 tokenizer + # create tokenizer # ---------------------------------------------------------------------------- # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) @@ -132,7 +139,7 @@ def test( ) # ---------------------------------------------------------------------------- # - # token编码 + # tokenize # ---------------------------------------------------------------------------- # # prompt = "山东最高的山是?" 
if isinstance(prompts, str): @@ -150,14 +157,26 @@ def test( "input_ids" ] # List: [[1, 1128, 526, 366, 29892]] - # 根据输入长度和最长输出长度创建KVCache - model.reset_cache( - 1 if prompts is str else len(prompts), - max_new_tokens + len(input_ids_list[0]), - ) + # ---------------------------------------------------------------------------- # + # Create KVCache + # ---------------------------------------------------------------------------- # + if enable_paged_attn: + batch_size = 1 if prompts is str else len(prompts) + max_total_tokens = max_new_tokens + len(input_ids_list[0]) + cache_config = PagedKVCacheConfig( + num_blocks=(max_total_tokens // 16 + 1) * batch_size, block_size=16 + ) + else: + batch_size = 1 if prompts is str else len(prompts) + initial_capacity = max_new_tokens + len(input_ids_list[0]) + cache_config = StaticKVCacheConfig( + max_batch_size=batch_size, max_cache_len=initial_capacity + ) + + model.reset_cache(cache_config) # ---------------------------------------------------------------------------- # - # 自回归生成 + # Generate # ---------------------------------------------------------------------------- # print(input_contents[0], end="", flush=True) input_ids_infini = infinicore.from_list(input_ids_list) @@ -211,7 +230,7 @@ def test( max_new_tokens = args.max_new_tokens backend = args.backend tp = args.tp - + enable_paged_attn = args.enable_paged_attn if backend != "cpp": raise ValueError(f"Unsupported backend: {backend}.") @@ -223,4 +242,5 @@ def test( max_new_tokens, infini_device=infini_device, tp=tp, + enable_paged_attn=enable_paged_attn, ) diff --git a/python/infinilm/auto_config.py b/python/infinilm/auto_config.py index 5408fe0b..83bac52b 100644 --- a/python/infinilm/auto_config.py +++ b/python/infinilm/auto_config.py @@ -21,5 +21,7 @@ def from_pretrained(model_path): if config_dict["model_type"] == "llama": return LlamaConfig(**config_dict) + elif config_dict["model_type"] == "qwen2": + return LlamaConfig(**config_dict) raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.") diff --git a/python/infinilm/cache/__init__.py b/python/infinilm/cache/__init__.py index b40add1c..a19e59f8 100644 --- a/python/infinilm/cache/__init__.py +++ b/python/infinilm/cache/__init__.py @@ -1,3 +1,3 @@ -from .cache import CacheConfig, StaticKVCacheConfig +from .cache import CacheConfig, StaticKVCacheConfig,PagedKVCacheConfig -__all__ = ["CacheConfig", "StaticKVCacheConfig"] +__all__ = ["CacheConfig", "StaticKVCacheConfig", "PagedKVCacheConfig"] diff --git a/python/infinilm/cache/cache.py b/python/infinilm/cache/cache.py index 31b8f49d..e0fe8168 100644 --- a/python/infinilm/cache/cache.py +++ b/python/infinilm/cache/cache.py @@ -16,11 +16,11 @@ def __init__(self, max_batch_size: int = 1, max_cache_len: int = 0): class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig): def __init__( self, - max_kv_memory_bytes: int, + num_blocks: int, block_size: int = 16, ): _infinilm.PagedKVCacheConfig.__init__( self, - max_kv_memory_bytes, + num_blocks, block_size, ) diff --git a/python/infinilm/generation/utils.py b/python/infinilm/generation/utils.py index 00143231..36f54cc6 100644 --- a/python/infinilm/generation/utils.py +++ b/python/infinilm/generation/utils.py @@ -13,6 +13,8 @@ def infini_to_ctype_dtype(infini_dtype): return ctypes.c_int32 elif infini_dtype == infinicore.float32: return ctypes.c_float + elif infini_dtype == infinicore.int64: + return ctypes.c_int64 else: raise ValueError(f"Unsupported py_dtype: {infini_dtype}") diff --git a/python/infinilm/infer_engine.py 
b/python/infinilm/infer_engine.py index 4be1008f..1a3e9255 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -4,7 +4,7 @@ import infinicore from infinilm.auto_config import AutoConfig -from infinilm.cache import StaticKVCacheConfig +from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig from infinilm.distributed import DistConfig from infinilm.lib import _infinilm @@ -18,6 +18,7 @@ class GenerationConfig: top_p: float = 1.0 eos_token_id: list[int] | None = None + stop_on_eos: bool = True class InferEngine(_infinilm.InferEngine): @@ -42,6 +43,8 @@ def __init__( self.use_cache = False + self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig) + def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) @@ -50,8 +53,8 @@ def forward( input_ids, *, position_ids=None, - cache_lengths=None, - input_lengths=None, + past_kv_lengths=None, + total_kv_lengths=None, input_offsets=None, block_tables=None, slot_mapping=None, @@ -62,8 +65,12 @@ def forward( # TODO: Remove `_underlying` and simplify the corresponding code. input_ids = input_ids._underlying if input_ids is not None else None position_ids = position_ids._underlying if position_ids is not None else None - cache_lengths = cache_lengths._underlying if cache_lengths is not None else None - input_lengths = input_lengths._underlying if input_lengths is not None else None + past_kv_lengths = ( + past_kv_lengths._underlying if past_kv_lengths is not None else None + ) + total_kv_lengths = ( + total_kv_lengths._underlying if past_kv_lengths is not None else None + ) input_offsets = input_offsets._underlying if input_offsets is not None else None block_tables = block_tables._underlying if block_tables is not None else None slot_mapping = slot_mapping._underlying if slot_mapping is not None else None @@ -74,8 +81,8 @@ def forward( super().Input( input_ids, position_ids=position_ids, - cache_lengths=cache_lengths, - input_lengths=input_lengths, + past_sequence_lengths=past_kv_lengths, + total_sequence_lengths=total_kv_lengths, input_offsets=input_offsets, block_tables=block_tables, slot_mapping=slot_mapping, @@ -87,21 +94,24 @@ def forward( .output_ids ) - def generate(self, input_ids, generation_config, *, _measure_and_log_time=False): + def generate( + self, + input_ids, + generation_config, + *, + _measure_and_log_time=False, + paged_block_size=16, + ): if generation_config.eos_token_id is None: eos_token_id = self.config.eos_token_id else: eos_token_id = generation_config.eos_token_id - # TODO: Remove the `to_numpy` calls and simplify the corresponding code. 
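The generation loop below lays blocks out round-robin across the batch: logical block i of request b lives at global block i * batch_size + b, and a token's slot is its global block id times block_size plus its offset within the block. A standalone sketch of that indexing (helper name is illustrative):

def paged_layout(batch_size, past_seq_len, seq_len, block_size=16):
    total = past_seq_len + seq_len
    blocks_per_req = (total + block_size - 1) // block_size
    block_tables = [[i * batch_size + b for i in range(blocks_per_req)]
                    for b in range(batch_size)]
    slot_mapping = [
        ((past_seq_len + i) // block_size * batch_size + b) * block_size
        + (past_seq_len + i) % block_size
        for b in range(batch_size)
        for i in range(seq_len)
    ]
    return block_tables, slot_mapping

# batch_size=2, past_seq_len=0, seq_len=3, block_size=4:
#   block_tables == [[0], [1]]; request 0 writes slots [0, 1, 2], request 1 writes [4, 5, 6]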
- batch_size, seq_len = input_ids.shape[:2] - - position_ids = infinicore.from_list( - [list(range(0, seq_len)) for _ in range(batch_size)], dtype=infinicore.int64 - ) - cache_lengths = infinicore.from_list([0], dtype=infinicore.int64) - + past_seq_len = 0 output_ids = [] + initial_batch_size, initial_seqlen = input_ids.shape[:2] + seq_len = initial_seqlen + batch_size = initial_batch_size if batch_size != 1 and generation_config.max_new_tokens is None: raise ValueError( @@ -111,14 +121,75 @@ def generate(self, input_ids, generation_config, *, _measure_and_log_time=False) if _measure_and_log_time: time_measurements = [] - for _ in range(0, generation_config.max_new_tokens): + for iter in range(0, generation_config.max_new_tokens): if _measure_and_log_time: start_time = time.perf_counter() + batch_size, seq_len = input_ids.shape[:2] + + if self.enable_paged_attn: + input_ids = input_ids.view([1, batch_size * seq_len]) + position_ids = infinicore.from_list( + list(range(past_seq_len, past_seq_len + seq_len)) * batch_size, + dtype=infinicore.int64, + ) + block_tables_list = [ + [ + i * batch_size + b + for i in range( + (past_seq_len + seq_len + paged_block_size - 1) + // paged_block_size + ) + ] + for b in range(batch_size) + ] + slot_mapping_list = [ + (((past_seq_len + i) // paged_block_size) * batch_size + b) + * paged_block_size + + (past_seq_len + i) % paged_block_size + for b in range(batch_size) + for i in range(seq_len) + ] + + block_tables = infinicore.from_list( + block_tables_list, + dtype=infinicore.int64, + ) + slot_mapping = infinicore.from_list( + slot_mapping_list, + dtype=infinicore.int64, + ) + else: + position_ids = infinicore.from_list( + [ + list(range(past_seq_len, past_seq_len + seq_len)) + for _ in range(batch_size) + ], + dtype=infinicore.int64, + ) + + block_tables = None + slot_mapping = None + + past_kv_lengths = infinicore.from_list( + [past_seq_len] * batch_size, dtype=infinicore.int64 + ) + total_kv_lengths = infinicore.from_list( + [past_seq_len + seq_len] * batch_size, dtype=infinicore.int64 + ) + + input_offsets = infinicore.from_list( + [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64 + ) + output_id = self( - input_ids, + input_ids=input_ids, position_ids=position_ids, - cache_lengths=cache_lengths, + past_kv_lengths=past_kv_lengths, + total_kv_lengths=total_kv_lengths, + input_offsets=input_offsets, + block_tables=block_tables, + slot_mapping=slot_mapping, temperature=generation_config.temperature, top_k=generation_config.top_k, top_p=generation_config.top_p, @@ -127,24 +198,17 @@ def generate(self, input_ids, generation_config, *, _measure_and_log_time=False) output_ids.append(output_id) if ( - generation_config.max_new_tokens is not None + initial_batch_size == 1 + and generation_config.stop_on_eos + and generation_config.max_new_tokens is not None and output_id.to_numpy()[0] in eos_token_id ): break - seq_len = position_ids.shape[-1] - input_ids = infinicore.from_list( [[output_id] for output_id in output_id.to_numpy().tolist()] ) - position_ids = infinicore.from_list( - [1 for _ in range(batch_size)], - dtype=position_ids.dtype, - device=position_ids.device, - ).view((batch_size, 1)) + position_ids.narrow(1, seq_len - 1, 1) - cache_lengths += infinicore.from_list( - [seq_len], dtype=cache_lengths.dtype, device=cache_lengths.device - ) + past_seq_len = past_seq_len + seq_len if _measure_and_log_time: end_time = time.perf_counter() @@ -156,23 +220,21 @@ def generate(self, input_ids, generation_config, *, _measure_and_log_time=False) 
f"\n\n\n Generation completed in {round(sum(time_measurements) * 1000, 2)} ms" ) print( - f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_measurements)}\n" + f" Batchsize={initial_batch_size} Per_Batch_Input_Len={initial_seqlen} Per_Batch_New_Tokens={len(time_measurements)}\n" ) print( - f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((batch_size * seq_len) / time_measurements[0], 2)}tok/s\n", + f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((initial_batch_size * initial_seqlen) / time_measurements[0], 2)}tok/s\n", ) if len(time_measurements) > 1: print( - f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n", + f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((initial_batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n", ) return output_ids - def reset_cache(self, batch_size: int, initial_capacity: int = 1024): + def reset_cache(self, cache_config): infinicore.sync_device() - - cache_config = StaticKVCacheConfig(batch_size, initial_capacity) - + self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig) super().reset_cache(cache_config) def state_dict_keyname(self):