Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 71 additions & 3 deletions csrc/cache/kv_cache.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "kv_cache.hpp"

#include "../utils.hpp"

#include "infinicore/ops.hpp"
#include <stdexcept>

namespace infinilm::cache {
Expand Down Expand Up @@ -155,6 +155,7 @@ PagedKVCache::PagedKVCache(
num_blocks_per_layer_ = config.max_kv_memory_bytes()
/ (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_)
/ block_size_
/ rank_num_layers_
/ infinicore::dsize(dtype_);
if (num_blocks_per_layer_ == 0) {
throw std::runtime_error("Not enough memory for KV cache");
Expand Down Expand Up @@ -187,11 +188,78 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
const infinicore::Tensor &v,
const infinicore::Tensor &slot_mapping) {

auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);

infinicore::op::paged_caching_(k,
v,
k_cache_layer,
v_cache_layer,
slot_mapping);
return {k_cache_layer, v_cache_layer};
}

std::tuple<infinicore::Tensor, infinicore::Tensor>
PagedKVCache::get_paged_kv(size_t layer_idx) {
    // Slice out this layer's K/V pages and drop the leading layer dimension,
    // yielding [num_blocks, num_rank_{k,v}_heads, block_size, {k,v}_dim] views.
    return {k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0),
            v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0)};
}

std::tuple<infinicore::Tensor, infinicore::Tensor>
PagedKVCache::get_contiguous_kv(
    size_t layer_idx,
    const infinicore::Tensor block_tables,
    const infinicore::Tensor cache_lens,
    const infinicore::Tensor input_offsets,
    size_t request_id) {
    ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64);
    ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64);
    ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64);

    // Guard against indexing past the batch dimension of the block table.
    if (request_id >= static_cast<size_t>(block_tables->size(0))) {
        throw std::runtime_error("get_contiguous_kv: request_id out of range");
    }

    // Bring the indexing metadata to the host so the block walk below can
    // read it directly, then wait for the transfers to land.
    auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
    auto cache_lens_cpu = cache_lens->to(infinicore::Device::cpu());
    auto input_offsets_cpu = input_offsets->to(infinicore::Device::cpu());
    infinicore::context::syncDevice();

    // Per-layer paged views:
    //   k_cache_layer: [num_blocks, num_rank_k_heads, block_size, k_dim]
    //   v_cache_layer: [num_blocks, num_rank_v_heads, block_size, v_dim]
    auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);

    // total_len = previously cached tokens + new input tokens for this request.
    auto cache_lens_ptr = reinterpret_cast<const int64_t *>(cache_lens_cpu->data());
    auto input_offsets_ptr = reinterpret_cast<const int64_t *>(input_offsets_cpu->data());
    int64_t total_len = cache_lens_ptr[request_id]
                      + (input_offsets_ptr[request_id + 1] - input_offsets_ptr[request_id]);

    auto full_k = infinicore::Tensor::empty(
        {num_rank_k_heads_, (size_t)total_len, k_dim_},
        k_cache_layer->dtype(), k_cache_layer->device());

    auto full_v = infinicore::Tensor::empty(
        {num_rank_v_heads_, (size_t)total_len, v_dim_},
        v_cache_layer->dtype(), v_cache_layer->device());

    size_t nblocks = total_len / block_size_;
    size_t remainder = total_len % block_size_;

    // Physical block id for logical block `b` of this request (host-side read).
    auto block_id_at = [&](size_t b) -> size_t {
        return *reinterpret_cast<const int64_t *>(
            block_tables_cpu->narrow({{0, request_id, 1}, {1, b, 1}})->data());
    };

    // Gather every full block into the contiguous output along the seq axis.
    for (size_t b = 0; b < nblocks; b++) {
        size_t bid = block_id_at(b);

        full_k->narrow({{1, b * block_size_, block_size_}})
            ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
        full_v->narrow({{1, b * block_size_, block_size_}})
            ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
    }

    // Gather the trailing partially-filled block, if any.
    if (remainder > 0) {
        size_t bid = block_id_at(nblocks);

        full_k->narrow({{1, nblocks * block_size_, remainder}})
            ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, remainder}}));
        full_v->narrow({{1, nblocks * block_size_, remainder}})
            ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, remainder}}));
    }

    // full_k: [num_rank_k_heads, total_len, k_dim]
    // full_v: [num_rank_v_heads, total_len, v_dim]
    return {full_k, full_v};
}

} // namespace infinilm::cache
38 changes: 36 additions & 2 deletions csrc/cache/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class PagedKVCache final : public Cache {
/**
* @brief Update Paged KV cache at a given layer given slot info for each token.
*
* @param layer_idx Which transformer layer
* @param layer_idx Which paged attention layer
* @param k [num_rank_k_heads, seq_len, k_dim]
* @param v [num_rank_v_heads, seq_len, v_dim]
* @param slot_mapping [seq_len]
Expand All @@ -128,7 +128,41 @@ class PagedKVCache final : public Cache {
const infinicore::Tensor &v,
const infinicore::Tensor &slot_mapping);

~PagedKVCache() override = default;
/**
* @brief Get Paged KV cache at a given layer.
*
* @param layer_idx Which paged attention layer
*
* @return (full_k, full_v)
* full_k: [num_blocks, num_rank_k_heads, block_size, k_dim]
* full_v: [num_blocks, num_rank_v_heads, block_size, v_dim]
*/
std::tuple<infinicore::Tensor, infinicore::Tensor>
get_paged_kv(size_t layer_idx);

/**
* @brief Get contiguous KV cache at a given layer, given the request info
* among a continuous request batch.
*
* @param layer_idx Which paged attention layer
* @param block_tables [num_requests, max_blocks_per_request]
* @param cache_lens [num_requests]
* @param input_offsets [num_requests + 1]
* @param request_id Which request among a continuous batch of requests
*
* @return (full_k, full_v)
* full_k: [num_rank_k_heads, total_len, k_dim]
* full_v: [num_rank_v_heads, total_len, v_dim]
*/
std::tuple<infinicore::Tensor, infinicore::Tensor>
get_contiguous_kv(size_t layer_idx,
const infinicore::Tensor block_tables,
const infinicore::Tensor cache_lens,
const infinicore::Tensor input_offsets,
size_t request_id = 0);

~PagedKVCache() override
= default;

private:
infinicore::Size k_dim_;
Expand Down
46 changes: 44 additions & 2 deletions csrc/engine/infer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,50 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
//------------------------------------------------------
// forward
//------------------------------------------------------
infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
return {input_ids, position_ids, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping};
infinilm::InfinilmModel::Input InferEngine::Input::to_model_input(infinicore::Device device) const {
    // Move an optional tensor onto the target device; empty optionals pass
    // through unchanged. Consolidates the repeated has_value()/to() pattern.
    auto to_device = [&](const std::optional<infinicore::Tensor> &t)
        -> std::optional<infinicore::Tensor> {
        if (t.has_value()) {
            return t.value()->to(device);
        }
        return std::nullopt;
    };

    // @todo: only paged kv cache supports device tensors so far, so
    // cache_lengths is moved to the device only when block_tables (i.e. paged
    // attention) is in use; otherwise it stays where it is.
    std::optional<infinicore::Tensor> cache_lengths_on_device;
    if (cache_lengths.has_value()) {
        cache_lengths_on_device = block_tables.has_value()
                                    ? cache_lengths.value()->to(device)
                                    : cache_lengths.value();
    }

    return {
        input_ids, // @todo: on device in the future
        to_device(position_ids),
        cache_lengths_on_device,
        to_device(input_lengths),
        to_device(input_offsets),
        to_device(block_tables),
        to_device(slot_mapping)};
}

InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
Expand Down
15 changes: 10 additions & 5 deletions csrc/engine/rank_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void RankWorker::thread_loop() {
local_param_name = pending_param_name_;
local_param = pending_param_;
} else if (local_cmd == Command::RUN) {
local_args = pending_args_.to_model_input();
local_args = pending_args_.to_model_input(rank_info_.device);
} else if (local_cmd == Command::RESET_CACHE) {
if (pending_cache_config_ != nullptr) {
local_cache_config = pending_cache_config_->unique_copy();
Expand Down Expand Up @@ -254,13 +254,18 @@ void RankWorker::thread_loop() {
auto random_val{pending_args_.random_val};

const auto &logits_shape{logits->shape()};
const auto &batch_size{logits_shape[0]};
const auto &vocab_size{logits_shape[2]};
const auto &total_len{logits_shape[1]};
const auto &batch_size{logits_shape[0]};

auto n_req = pending_args_.input_offsets.value()->size(0);
int64_t *input_lengths = (int64_t *)pending_args_.input_lengths.value()->data();
int64_t *input_offsets = (int64_t *)pending_args_.input_offsets.value()->data();

auto output_ids{infinicore::Tensor::empty({batch_size}, infinicore::DataType::I32, rank_info_.device)};
auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};

for (auto i{decltype(batch_size)(0)}; i < batch_size; ++i) {
auto score{logits->narrow({{0, i, 1}})->view({vocab_size})};
for (auto i{decltype(n_req)(0)}; i < n_req; ++i) {
auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i] + input_lengths[i] - 1), 1}})->view({vocab_size})};
auto out{output_ids->narrow({{0, i, 1}})->view({})};
infinicore::op::random_sample_(
out, score, random_val, top_p, top_k, temperature);
Expand Down
2 changes: 1 addition & 1 deletion csrc/engine/rank_worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class RankWorker {

float random_val{0.1};

infinilm::InfinilmModel::Input to_model_input() const;
infinilm::InfinilmModel::Input to_model_input(infinicore::Device device) const;
};

struct Output {
Expand Down
Loading