diff --git a/CMakeLists.txt b/CMakeLists.txt
index 645ce0a2..51074963 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,6 +3,9 @@ set_property(GLOBAL PROPERTY USE_FOLDERS ON)
 option(USE_NPU "Enable NPU support" OFF)
 option(USE_MLU "Enable MLU support" OFF)
 
+add_compile_definitions(YLT_ENABLE_IBV)
+add_definitions(-DYLT_ENABLE_IBV)
+set(YLT_ENABLE_IBV ON)
 if(DEVICE_ARCH STREQUAL "ARM")
   set(CMAKE_SYSTEM_PROCESSOR aarch64)
diff --git a/setup.py b/setup.py
index 5b43e398..44809388 100644
--- a/setup.py
+++ b/setup.py
@@ -445,7 +445,7 @@ def check_and_install_pre_commit():
         print("Run 'pre-commit install' failed. Please install pre-commit: pip install pre-commit")
         exit(0)
 
-def run_git_command(command, cwd=None, check=True):
+def run_shell_command(command, cwd=None, check=True):
     try:
         subprocess.run(command, cwd=cwd, check=check, shell=True, capture_output=True, text=True)
         return True
@@ -492,15 +492,15 @@ def apply_patch_safely(patch_file_path, repo_path):
 
     if has_uncommitted_changes(repo_path):
         print(f"⚠️ Uncommitted changes detected. Running `git reset --hard` for {repo_path}")
-        if not run_git_command("git reset --hard", cwd=repo_path):
+        if not run_shell_command("git reset --hard", cwd=repo_path):
             print("❌ Failed to reset changes!")
             return False
 
     print(f"🛠️ Apply patch: {patch_file_path}")
-    apply_success = run_git_command(f"git apply --check {patch_file_path}", cwd=repo_path, check=False)
+    apply_success = run_shell_command(f"git apply --check {patch_file_path}", cwd=repo_path, check=False)
 
     if apply_success:
-        if not run_git_command(f"git apply {patch_file_path}", cwd=repo_path):
+        if not run_shell_command(f"git apply {patch_file_path}", cwd=repo_path):
             print("❌ Failed to apply patch!")
             apply_success = False
 
@@ -512,7 +512,7 @@ def apply_patch_safely(patch_file_path, repo_path):
         print(f"    cd {repo_path} && git apply {patch_file_path}")
         return False
 
-def apply_patch():
+def pre_build():
     if os.path.exists("third_party/custom_patch"):
         script_path = os.path.dirname(os.path.abspath(__file__))
         mooncake_repo_path = os.path.join(script_path, "third_party/Mooncake")
@@ -521,6 +521,9 @@ def apply_patch():
         cpprestsdk_repo_path = os.path.join(script_path, "third_party/cpprestsdk")
         if not apply_patch_safely("../custom_patch/cpprestsdk.patch", cpprestsdk_repo_path):
             exit(0)
+        if not run_shell_command("sh third_party/dependencies.sh", cwd=script_path):
+            print("❌ Failed to run third_party/dependencies.sh!")
+            exit(0)
 
 if __name__ == "__main__":
     device = 'a2'  # default
@@ -537,9 +540,10 @@ def apply_patch():
             del sys.argv[idx]
             del sys.argv[idx]
     if '--dry_run' not in sys.argv:
-        apply_patch()
+        pre_build()
     else:
         sys.argv.remove("--dry_run")
+
     if '--install-xllm-kernels' in sys.argv:
         idx = sys.argv.index('--install-xllm-kernels')
         if idx + 1 < len(sys.argv):
diff --git a/third_party/Mooncake b/third_party/Mooncake
index fb26af76..be894977 160000
--- a/third_party/Mooncake
+++ b/third_party/Mooncake
@@ -1 +1 @@
-Subproject commit fb26af7613d4251c9c006c9ff7eef5ff4e18ed65
+Subproject commit be894977d926c5fff03735e8aa37e93aaaf041bc
diff --git a/third_party/custom_patch/Mooncake.patch b/third_party/custom_patch/Mooncake.patch
index 85a8f4b9..51a0845e 100644
--- a/third_party/custom_patch/Mooncake.patch
+++ b/third_party/custom_patch/Mooncake.patch
@@ -26,7 +26,7 @@ index 8483085..9d263dd 100644
  add_subdirectory(benchmarks)
 \ No newline at end of file
 diff --git a/mooncake-store/include/offset_allocator/offset_allocator.hpp b/mooncake-store/include/offset_allocator/offset_allocator.hpp
-index fde978b..ac54f8c 100644
+index b6d55c8..2d80158 100644
 --- a/mooncake-store/include/offset_allocator/offset_allocator.hpp
 +++ b/mooncake-store/include/offset_allocator/offset_allocator.hpp
 @@ -6,7 +6,7 @@
 
 namespace mooncake::offset_allocator {
 typedef unsigned char uint8;
-diff --git a/mooncake-store/include/storage_backend.h b/mooncake-store/include/storage_backend.h
-index a30f0c6..3bfeffb 100644
---- a/mooncake-store/include/storage_backend.h
-+++ b/mooncake-store/include/storage_backend.h
-@@ -4,8 +4,8 @@
- #include
- #include
- #include
--#include
--#include
-+#include "types.h"
-+#include "file_interface.h"
- #include
- #include
- #include
 diff --git a/mooncake-store/include/types.h b/mooncake-store/include/types.h
-index d2830a3..39c3b27 100644
+index 077926b..b36862b 100644
 --- a/mooncake-store/include/types.h
 +++ b/mooncake-store/include/types.h
-@@ -10,8 +10,11 @@
- #include
- #include
--#include "Slab.h"
+@@ -8,7 +8,10 @@
+ #include
+ #include
+-#include "Slab.h"
+#include "cachelib_memory_allocator/Slab.h"
- #include "allocator.h"
+namespace iguana {
+using std::contiguous_iterator;
+}
 #include "ylt/struct_json/json_writer.h"
diff --git a/mooncake-store/src/ha_helper.cpp b/mooncake-store/src/ha_helper.cpp
-index c1f4ded..e4b982b 100644
+index 796838a..8e72b80 100644
 --- a/mooncake-store/src/ha_helper.cpp
 +++ b/mooncake-store/src/ha_helper.cpp
 @@ -1,3 +1,6 @@
 
 #include "ha_helper.h"
 #include "etcd_helper.h"
 #include "rpc_service.h"
-@@ -169,4 +172,6 @@ MasterServiceSupervisor::~MasterServiceSupervisor() {
+@@ -174,4 +177,6 @@ MasterServiceSupervisor::~MasterServiceSupervisor() {
     }
 }
 
@@ -113,7 +97,7 @@ index f515671..7a019b6 100644
 +#pragma GCC pop_options
 \ No newline at end of file
 diff --git a/mooncake-store/src/utils.cpp b/mooncake-store/src/utils.cpp
-index b775851..736d61d 100644
+index 9678f57..f41eb10 100644
 --- a/mooncake-store/src/utils.cpp
 +++ b/mooncake-store/src/utils.cpp
 @@ -1,6 +1,6 @@
diff --git a/third_party/dependencies.sh b/third_party/dependencies.sh
new file mode 100644
index 00000000..ba3c1ffc
--- /dev/null
+++ b/third_party/dependencies.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+
+# Color definitions
+GREEN="\033[0;32m"
+BLUE="\033[0;34m"
+YELLOW="\033[0;33m"
+RED="\033[0;31m"
+NC="\033[0m" # No Color
+
+# Configuration
+REPO_ROOT=`pwd`
+
+# Function to print section headers
+print_section() {
+    echo -e "\n${BLUE}=== $1 ===${NC}"
+}
+
+# Function to print success messages
+print_success() {
+    echo -e "${GREEN}✓ $1${NC}"
+}
+
+# Function to print error messages and exit
+print_error() {
+    echo -e "${RED}✗ ERROR: $1${NC}"
+    exit 1
+}
+
+# Function to check command success
+check_success() {
+    if [ $? -ne 0 ]; then
+        print_error "$1"
+    fi
+}
+
+if [ $(id -u) -ne 0 ]; then
+    print_error "Root permission required; try: sudo ./dependencies.sh"
+fi
+
+
+# Install yalantinglibs
+print_section "Installing yalantinglibs"
+
+# Check if thirdparties directory exists
+if [ ! -d "${REPO_ROOT}/third_party/Mooncake/thirdparties" ]; then
+    mkdir -p "${REPO_ROOT}/third_party/Mooncake/thirdparties"
+    check_success "Failed to create Mooncake/thirdparties directory"
+fi
+
+# Change to thirdparties directory
+cd "${REPO_ROOT}/third_party/Mooncake/thirdparties"
+check_success "Failed to change to Mooncake/thirdparties directory"
+
+# Check if yalantinglibs is already installed
+if [ -d "yalantinglibs" ]; then
+    echo -e "${YELLOW}yalantinglibs directory already exists. Removing for fresh install...${NC}"
+    rm -rf yalantinglibs
+    check_success "Failed to remove existing yalantinglibs directory"
+fi
+
+# Clone yalantinglibs
+echo "Cloning yalantinglibs from https://github.com/alibaba/yalantinglibs.git"
+git clone https://github.com/alibaba/yalantinglibs.git
+check_success "Failed to clone yalantinglibs"
+
+# Build and install yalantinglibs
+cd yalantinglibs
+check_success "Failed to change to yalantinglibs directory"
+
+# Checkout version 0.5.5
+echo "Checking out yalantinglibs version 0.5.5..."
+git checkout 0.5.5
+check_success "Failed to checkout yalantinglibs version 0.5.5"
+
+mkdir -p build
+check_success "Failed to create build directory"
+
+cd build
+check_success "Failed to change to build directory"
+
+echo "Configuring yalantinglibs..."
+cmake .. -DBUILD_EXAMPLES=OFF -DBUILD_BENCHMARK=OFF -DBUILD_UNIT_TESTS=OFF -DYLT_ENABLE_IBV=ON
+check_success "Failed to configure yalantinglibs"
+
+echo "Building yalantinglibs (using $(nproc) cores)..."
+cmake --build . -j$(nproc)
+check_success "Failed to build yalantinglibs"
+
+echo "Installing yalantinglibs..."
+cmake --install .
+check_success "Failed to install yalantinglibs"
+
+sed -i '54s/target_link_libraries(${ylt_target_name} -libverbs)/target_link_libraries(${ylt_target_name} INTERFACE -libverbs)/' /usr/local/lib/cmake/yalantinglibs/config.cmake
+
+print_success "yalantinglibs installed successfully"
+
diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp
index 9e41164e..ceed4ffa 100644
--- a/xllm/core/common/global_flags.cpp
+++ b/xllm/core/common/global_flags.cpp
@@ -332,14 +332,18 @@
 DEFINE_string(store_protocol,
               "tcp",
               "KV cache store protocol(e.g. tcp, rdma).");
 
-DEFINE_string(store_master_server_entry,
+DEFINE_string(store_master_server_address,
               "",
               "The address information of the store master service.");
 
-DEFINE_string(store_metadata_connstring,
+DEFINE_string(store_metadata_server,
               "",
               "The address of the kv cache store metadata service.");
 
+DEFINE_string(store_local_hostname,
+              "",
+              "The local host name of the kv cache store client.");
+
 // --- computation communication parallel config ---
 
 DEFINE_bool(
diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h
index 7fc36442..b7304e7c 100644
--- a/xllm/core/common/global_flags.h
+++ b/xllm/core/common/global_flags.h
@@ -159,9 +159,11 @@ DECLARE_bool(enable_kvcache_store);
 
 DECLARE_string(store_protocol);
 
-DECLARE_string(store_master_server_entry);
+DECLARE_string(store_master_server_address);
 
-DECLARE_string(store_metadata_connstring);
+DECLARE_string(store_metadata_server);
+
+DECLARE_string(store_local_hostname);
 
 DECLARE_bool(enable_multi_stream_parallel);
 
diff --git a/xllm/core/common/options.cpp b/xllm/core/common/options.cpp
index 728e742b..fa7542a9 100644
--- a/xllm/core/common/options.cpp
+++ b/xllm/core/common/options.cpp
@@ -53,8 +53,9 @@ std::string Options::to_string() const {
      << ", enable_cache_upload: " << enable_cache_upload()
      << ", enable_kvcache_store: " << enable_kvcache_store()
      << ", store_protocol: " << store_protocol()
-     << ", store_master_server_entry: " << store_master_server_entry()
-     << ", store_metadata_connstring: " << store_metadata_connstring()
+     << ", store_master_server_address: " << store_master_server_address()
+     << ", store_metadata_server: " << store_metadata_server()
+     << ", store_local_hostname: " << store_local_hostname()
      << ", enable_multi_stream_parallel: " << enable_multi_stream_parallel()
      << ", enable_continuous_kvcache: " << enable_continuous_kvcache()
<< ", disable_ttft_profiling: " << disable_ttft_profiling() diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 7815b4a4..895c1b2c 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -143,9 +143,11 @@ class Options { PROPERTY(std::string, store_protocol) = "tcp"; - PROPERTY(std::string, store_master_server_entry) = ""; + PROPERTY(std::string, store_master_server_address) = ""; - PROPERTY(std::string, store_metadata_connstring) = ""; + PROPERTY(std::string, store_metadata_server) = ""; + + PROPERTY(std::string, store_local_hostname) = ""; PROPERTY(bool, enable_multi_stream_parallel) = false; diff --git a/xllm/core/distributed_runtime/comm_channel.cpp b/xllm/core/distributed_runtime/comm_channel.cpp index 720a6b91..19eb1c87 100644 --- a/xllm/core/distributed_runtime/comm_channel.cpp +++ b/xllm/core/distributed_runtime/comm_channel.cpp @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include + namespace xllm { bool CommChannel::init_brpc(const std::string& server_address) { @@ -306,21 +308,121 @@ bool CommChannel::allocate_kv_cache_with_transfer( return true; } -bool CommChannel::load_kv_blocks_from_store_async( - const std::vector& cache_block_info, +void CommChannel::transfer_kv_blocks( + const std::vector& block_transfer_info, folly::Promise& promise) { - proto::CacheBlockInfos pb_cache_block_info; - if (!cache_block_info_to_proto(cache_block_info, &pb_cache_block_info)) { + proto::BlockTransferInfos pb_block_transfer_info; + if (!block_transfer_info_to_proto( + 0x0, block_transfer_info, &pb_block_transfer_info)) { promise.setValue(0); - return false; + return; } - auto done = new LoadKVCacheFromStoreClosure(); + auto done = new TransferBlocksClosure(); done->promise = std::move(promise); - stub_->LoadKVCacheFromStore( - &done->cntl, &pb_cache_block_info, &done->response, done); + stub_->TransferBlocks( + &done->cntl, &pb_block_transfer_info, &done->response, done); +} - return true; +void CommChannel::transfer_kv_blocks( + const uint64_t batch_id, + const std::vector& block_transfer_info) { + proto::BlockTransferInfos pb_block_transfer_info; + if (!block_transfer_info_to_proto( + batch_id, block_transfer_info, &pb_block_transfer_info)) { + return; + } + brpc::Controller cntl; + proto::TransferStatus response; + stub_->TransferBlocks(&cntl, &pb_block_transfer_info, &response, nullptr); +} + +class ClientStreamReceiver : public brpc::StreamInputHandler { + private: + const std::atomic& termination_flag_; + std::shared_ptr> success_cnt_; + std::promise close_promise_; + std::atomic promise_set_{false}; + + public: + ClientStreamReceiver(const std::atomic& termination_flag, + std::shared_ptr>& success_cnt) + : termination_flag_(termination_flag), success_cnt_(success_cnt) {} + + ~ClientStreamReceiver() { + if (!promise_set_.exchange(true)) { + try { + close_promise_.set_value(); + } catch (const std::exception& e) { + LOG(WARNING) << "Exception in destructor: " << e.what(); + } + } + } + + std::future get_close_future() { return close_promise_.get_future(); } + + int on_received_messages(brpc::StreamId id, + butil::IOBuf* const messages[], + size_t size) override { + for (size_t i = 0; i < size; ++i) { + std::string msg_str = messages[i]->to_string(); + int32_t success_cnt = std::stoi(msg_str); + + if (success_cnt > 0 && + !termination_flag_.load(std::memory_order_acquire)) { + success_cnt_->fetch_add(success_cnt, std::memory_order_relaxed); + } else { + brpc::StreamClose(id); + if (!promise_set_.exchange(true)) { + 
close_promise_.set_value(); + } + break; + } + } + return 0; + } + + virtual void on_idle_timeout(brpc::StreamId id) override { + if (!promise_set_.exchange(true)) { + close_promise_.set_value(); + } + } + + virtual void on_closed(brpc::StreamId id) override { + if (!promise_set_.exchange(true)) { + close_promise_.set_value(); + } + } +}; + +void CommChannel::prefetch_from_storage( + const std::atomic& flag, + const std::vector& block_transfer_info, + std::shared_ptr>& success_cnt) { + proto::BlockTransferInfos pb_block_transfer_info; + if (!block_transfer_info_to_proto( + 0x0, block_transfer_info, &pb_block_transfer_info)) { + return; + } + ClientStreamReceiver receiver(flag, success_cnt); + brpc::Controller cntl; + brpc::StreamOptions stream_options; + brpc::StreamId stream_id; + proto::Status response; + stream_options.handler = &receiver; + if (brpc::StreamCreate(&stream_id, cntl, &stream_options) != 0) { + LOG(ERROR) << "Failed to create stream"; + return; + } + + stub_->PrefetchFromStorage( + &cntl, &pb_block_transfer_info, &response, nullptr); + + if (cntl.Failed()) { + LOG(ERROR) << "Fail to connect stream, " << cntl.ErrorText(); + } + + receiver.get_close_future().wait(); } bool CommChannel::get_last_step_result_async( @@ -397,18 +499,6 @@ bool CommChannel::execute_model_with_brpc( return true; } -void LoadKVCacheFromStoreClosure::Run() { - std::unique_ptr self_guard(this); - - bool success = !cntl.Failed(); - if (!success) { - promise.setValue(0); - } else { - promise.setValue(response.success_cnt()); - } - return; -} - void ExecuteModelClosure::Run() { std::unique_ptr self_guard(this); @@ -437,4 +527,17 @@ void InitModelClosure::Run() { return; } + +void TransferBlocksClosure::Run() { + std::unique_ptr self_guard(this); + + bool success = !cntl.Failed(); + if (!success) { + promise.setValue(0); + } else { + promise.setValue(response.success_cnt()); + } + return; +} + } // namespace xllm \ No newline at end of file diff --git a/xllm/core/distributed_runtime/comm_channel.h b/xllm/core/distributed_runtime/comm_channel.h index a97850c0..975f4bda 100644 --- a/xllm/core/distributed_runtime/comm_channel.h +++ b/xllm/core/distributed_runtime/comm_channel.h @@ -87,10 +87,19 @@ class CommChannel { const uint64_t kv_cache_size, const std::vector>& kv_cache_shape); - virtual bool load_kv_blocks_from_store_async( - const std::vector& cache_block_info, + virtual void transfer_kv_blocks( + const std::vector& block_transfer_info, folly::Promise& promise); + virtual void transfer_kv_blocks( + const uint64_t batch_id, + const std::vector& block_transfer_info); + + virtual void prefetch_from_storage( + const std::atomic& flag, + const std::vector& block_transfer_info, + std::shared_ptr>& success_cnt); + virtual bool get_last_step_result_async( folly::Promise>& promise); @@ -128,11 +137,11 @@ class ExecuteModelClosure : public google::protobuf::Closure { folly::Promise> promise; }; -class LoadKVCacheFromStoreClosure : public google::protobuf::Closure { +class TransferBlocksClosure : public google::protobuf::Closure { public: void Run(); - proto::StoreResponse response; + proto::TransferStatus response; brpc::Controller cntl; folly::Promise promise; }; diff --git a/xllm/core/distributed_runtime/remote_worker.cpp b/xllm/core/distributed_runtime/remote_worker.cpp index edf3b8be..fc244cf8 100644 --- a/xllm/core/distributed_runtime/remote_worker.cpp +++ b/xllm/core/distributed_runtime/remote_worker.cpp @@ -35,6 +35,7 @@ limitations under the License. 
#include "util/hash_util.h" namespace xllm { + RemoteWorker::RemoteWorker(int32_t global_rank, const std::string& server_address, const torch::Device& d, @@ -282,18 +283,43 @@ folly::SemiFuture RemoteWorker::pull_kv_blocks_async( return future; } -folly::SemiFuture RemoteWorker::load_kv_blocks_from_store_async( - const std::vector cache_block_info) { +folly::SemiFuture RemoteWorker::transfer_kv_blocks( + const std::vector& block_transfer_info) { folly::Promise promise; auto future = promise.getSemiFuture(); - general_threadpool_.schedule([this, - cache_block_info = std::move(cache_block_info), - promise = std::move(promise)]() mutable { - channel_->load_kv_blocks_from_store_async(cache_block_info, promise); - }); + copy_threadpool_.schedule( + [this, + block_transfer_info = std::move(block_transfer_info), + promise = std::move(promise)]() mutable { + channel_->transfer_kv_blocks(block_transfer_info, promise); + }); return future; } +void RemoteWorker::transfer_kv_blocks( + const uint64_t batch_id, + const std::vector& block_transfer_info) { + copy_threadpool_.schedule( + [this, + batch_id = batch_id, + block_transfer_info = std::move(block_transfer_info)]() mutable { + channel_->transfer_kv_blocks(batch_id, block_transfer_info); + }); +} + +void RemoteWorker::prefetch_from_storage( + const std::atomic& flag, + const std::vector& block_transfer_info, + std::shared_ptr>& success_cnt) { + copy_threadpool_.schedule( + [this, + flag = &flag, + block_transfer_info = std::move(block_transfer_info), + success_cnt = success_cnt]() mutable { + channel_->prefetch_from_storage(flag, block_transfer_info, success_cnt); + }); +} + const torch::Device& RemoteWorker::device() const { LOG(ERROR) << "RemoteWorker Method device is UnImplemented."; } diff --git a/xllm/core/distributed_runtime/remote_worker.h b/xllm/core/distributed_runtime/remote_worker.h index b6b4601e..478f9e73 100644 --- a/xllm/core/distributed_runtime/remote_worker.h +++ b/xllm/core/distributed_runtime/remote_worker.h @@ -110,8 +110,17 @@ class RemoteWorker : public WorkerClient { const std::vector& src_blocks, const std::vector& dst_blocks); - virtual folly::SemiFuture load_kv_blocks_from_store_async( - const std::vector cache_block_info); + virtual folly::SemiFuture transfer_kv_blocks( + const std::vector& block_transfer_info) override; + + virtual void transfer_kv_blocks( + const uint64_t batch_id, + const std::vector& block_transfer_info) override; + + virtual void prefetch_from_storage( + const std::atomic& flag, + const std::vector& block_transfer_info, + std::shared_ptr>& success_cnt) override; // Run the model and return the output. 
virtual folly::SemiFuture> step_async( @@ -140,9 +149,8 @@ class RemoteWorker : public WorkerClient { // connection resource std::unique_ptr channel_; ThreadPool threadpool_; - // general working thread - // do some overlap work with model execute - ThreadPool general_threadpool_{4}; + // copy working thread + ThreadPool copy_threadpool_{4}; const torch::Device device_; }; } // namespace xllm diff --git a/xllm/core/distributed_runtime/worker_service.cpp b/xllm/core/distributed_runtime/worker_service.cpp index b9c1fd74..6982478c 100644 --- a/xllm/core/distributed_runtime/worker_service.cpp +++ b/xllm/core/distributed_runtime/worker_service.cpp @@ -417,18 +417,128 @@ void WorkerService::PullKVCache(::google::protobuf::RpcController* controller, return; } -void WorkerService::LoadKVCacheFromStore( +void WorkerService::TransferBlocks( ::google::protobuf::RpcController* controller, - const ::xllm::proto::CacheBlockInfos* req, - ::xllm::proto::StoreResponse* resp, + const proto::BlockTransferInfos* req, + proto::TransferStatus* resp, ::google::protobuf::Closure* done) { brpc::ClosureGuard done_guard(done); - std::vector dst_blocks; - proto_to_cache_block_info(*req, dst_blocks); + std::vector block_transfer_info; + uint64_t batch_id = proto_to_block_transfer_info(*req, block_transfer_info); - auto future = worker_->load_kv_blocks_from_store_async(dst_blocks); + if (batch_id == 0x0) { + resp->set_success_cnt(worker_->transfer_kv_blocks(block_transfer_info)); + } else { + worker_->transfer_kv_blocks(batch_id, std::move(block_transfer_info)); + } + return; +} + +class ServerStreamHandler : public brpc::StreamInputHandler { + private: + std::promise close_promise_; + std::atomic promise_set_{false}; + + public: + ~ServerStreamHandler() { + if (!promise_set_.exchange(true)) { + try { + close_promise_.set_value(); + } catch (const std::exception& e) { + LOG(WARNING) << "Exception in destructor: " << e.what(); + } + } + } + + std::future get_close_future() { return close_promise_.get_future(); } + + int on_received_messages(brpc::StreamId id, + butil::IOBuf* const messages[], + size_t size) override { + LOG(WARNING) << "ServerStreamHandler::on_received_messages not implement."; + return 0; + } + + void on_closed(brpc::StreamId id) override { + if (!promise_set_.exchange(true)) { + close_promise_.set_value(); + } + } + + void on_idle_timeout(brpc::StreamId id) override { + if (!promise_set_.exchange(true)) { + LOG(WARNING) << "Stream idle timeout: " << id; + close_promise_.set_value(); + } + } +}; + +void WorkerService::PrefetchFromStorage( + google::protobuf::RpcController* controller, + const proto::BlockTransferInfos* req, + proto::Status* resp, + google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + brpc::Controller* cntl = static_cast(controller); + + auto stream_handler = std::make_unique(); + auto stream_id = std::make_unique(); + brpc::StreamOptions stream_options; + stream_options.handler = stream_handler.get(); + if (brpc::StreamAccept(stream_id.get(), *cntl, &stream_options) != 0) { + resp->set_ok(false); + LOG(ERROR) << "Failed to accept stream!"; + return; + } + + std::vector block_transfer_info; + proto_to_block_transfer_info(*req, block_transfer_info); + + copy_threadpool_.schedule( + [this, + block_transfer_info = std::move(block_transfer_info), + stream_id = std::move(stream_id), + stream_handler = std::move(stream_handler)]() mutable { + Slice transfer_slice{block_transfer_info}; + auto close_future = stream_handler->get_close_future(); + bool is_completed = false; 
+ + for (size_t i = 0; i < transfer_slice.size(); + i += stream_copy_batch_size_) { + auto current_slice = transfer_slice.slice( + i, std::min(i + stream_copy_batch_size_, transfer_slice.size())); + + auto success_cnt = worker_->prefetch_from_storage(current_slice); + + if (success_cnt != current_slice.size() || + i + stream_copy_batch_size_ >= transfer_slice.size()) { + is_completed = true; + } + + butil::IOBuf buf; + buf.append(std::to_string(success_cnt)); + if (brpc::StreamWrite(*stream_id.get(), buf) != 0) { + brpc::StreamClose(*stream_id.get()); + is_completed = false; + break; + } + + if (is_completed) { + if (success_cnt != 0) { + butil::IOBuf buf_end; + buf_end.append("0"); + brpc::StreamWrite(*stream_id.get(), buf_end); + } + break; + } + } + if (is_completed) { + close_future.wait(); + } + brpc::StreamClose(*stream_id.get()); + }); - resp->set_success_cnt(std::move(future).get()); + resp->set_ok(true); return; } diff --git a/xllm/core/distributed_runtime/worker_service.h b/xllm/core/distributed_runtime/worker_service.h index efc560a9..2f4d6c63 100644 --- a/xllm/core/distributed_runtime/worker_service.h +++ b/xllm/core/distributed_runtime/worker_service.h @@ -80,11 +80,15 @@ class WorkerService : public proto::DistributeWorker { proto::Status* resp, ::google::protobuf::Closure* done) override; - virtual void LoadKVCacheFromStore( - ::google::protobuf::RpcController* controller, - const ::xllm::proto::CacheBlockInfos* req, - ::xllm::proto::StoreResponse* resp, - ::google::protobuf::Closure* done) override; + void TransferBlocks(::google::protobuf::RpcController* controller, + const proto::BlockTransferInfos* req, + proto::TransferStatus* resp, + ::google::protobuf::Closure* done) override; + + void PrefetchFromStorage(google::protobuf::RpcController* controller, + const proto::BlockTransferInfos* req, + proto::Status* resp, + google::protobuf::Closure* done) override; void GetDeviceInfo(::google::protobuf::RpcController* controller, const proto::Empty* req, @@ -150,6 +154,9 @@ class WorkerService : public proto::DistributeWorker { std::unique_ptr polling_thread_; std::unique_ptr threadpool_; + ThreadPool copy_threadpool_{5}; + + uint32_t stream_copy_batch_size_ = 2; }; } // namespace xllm diff --git a/xllm/core/framework/batch/batch.cpp b/xllm/core/framework/batch/batch.cpp index d2de8049..a73606be 100644 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -73,9 +73,8 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens, allowed_max_tokens_, input_embeddings_vec_, mm_data_vec_, - copy_in_cache_block_infos_, - copy_out_cache_block_infos_, - swap_cache_block_infos_, + swap_block_transfer_infos_, + batch_id_, &args); return builder.build_forward_input(num_decoding_tokens, min_decoding_batch_size); @@ -88,9 +87,8 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx, allowed_max_tokens_, input_embeddings_vec_, mm_data_vec_, - copy_in_cache_block_infos_, - copy_out_cache_block_infos_, - swap_cache_block_infos_, + swap_block_transfer_infos_, + batch_id_, nullptr, thread_pool); return builder.build_raw_forward_input(start_idx, end_idx); diff --git a/xllm/core/framework/batch/batch.h b/xllm/core/framework/batch/batch.h index f862b305..3fe3ed1d 100644 --- a/xllm/core/framework/batch/batch.h +++ b/xllm/core/framework/batch/batch.h @@ -16,6 +16,8 @@ limitations under the License. 
 
 #pragma once
 
+#include <absl/time/clock.h>
+#include <absl/time/time.h>
 #include
 #include
 
@@ -48,20 +50,18 @@ class Batch {
     sequence_groups_.push_back(sequence_group);
   }
 
-  void set_copy_in_cache_block_infos(
-      std::vector<CacheBlockInfo>* copy_in_cache_block_infos) {
-    copy_in_cache_block_infos_ = copy_in_cache_block_infos;
+  void set_swap_block_transfer_infos(
+      std::vector<BlockTransferInfo>* swap_block_transfer_infos) {
+    swap_block_transfer_infos_ = swap_block_transfer_infos;
   }
 
-  void set_copy_out_cache_block_infos(
-      std::vector<CacheBlockInfo>* copy_out_cache_block_infos) {
-    copy_out_cache_block_infos_ = copy_out_cache_block_infos;
+  void set_batch_id() {
+    if (batch_id_ == 0x0) {
+      batch_id_ = absl::ToUnixMicros(absl::Now());
+    }
   }
 
-  void set_swap_cache_block_infos(
-      std::vector<CacheBlockInfo>* swap_cache_block_infos) {
-    swap_cache_block_infos_ = swap_cache_block_infos;
-  }
+  uint64_t batch_id() const { return batch_id_; }
 
   // get the number of sequences in the batch
   size_t size() const { return sequences_.size(); }
@@ -123,9 +123,7 @@ class Batch {
   std::vector<Sequence*> sequences_;
   std::vector<SequenceGroup*> sequence_groups_;
 
-  std::vector<CacheBlockInfo>* copy_in_cache_block_infos_ = nullptr;
-  std::vector<CacheBlockInfo>* copy_out_cache_block_infos_ = nullptr;
-  std::vector<CacheBlockInfo>* swap_cache_block_infos_ = nullptr;
+  std::vector<BlockTransferInfo>* swap_block_transfer_infos_ = nullptr;
 
   // max number of tokens to process for each sequence
   // default to max value
@@ -138,6 +136,8 @@ class Batch {
 
   // all sequences in this batch are in prefill stage
   bool all_seqs_in_prefill_ = false;
+
+  uint64_t batch_id_ = 0x0;
 };
 } // namespace xllm
diff --git a/xllm/core/framework/batch/batch_factory.cpp b/xllm/core/framework/batch/batch_factory.cpp
index 5dd9d428..9217dbed 100644
--- a/xllm/core/framework/batch/batch_factory.cpp
+++ b/xllm/core/framework/batch/batch_factory.cpp
@@ -33,9 +33,7 @@ std::vector<Batch> BatchFactory::create_batches(
     const std::vector<std::shared_ptr<Request>>& running_requests,
     const std::vector<Sequence*>& running_sequences,
     const std::vector<size_t>& running_sequences_budgets,
-    std::vector<std::vector<CacheBlockInfo>>* copy_in_cache_block_infos,
-    std::vector<std::vector<CacheBlockInfo>>* copy_out_cache_block_infos,
-    std::vector<std::vector<CacheBlockInfo>>* swap_cache_block_infos) {
+    std::vector<std::vector<BlockTransferInfo>>* swap_block_transfer_infos) {
   size_t num_prompt_tokens = 0;
   size_t num_generated_tokens = 0;
   std::vector<Batch> batches(dp_size_);
@@ -74,19 +72,10 @@ std::vector<Batch> BatchFactory::create_batches(
 
   for (int i = 0; i < dp_size_; i++) {
     if (!batches[i].empty()) {
-      if (copy_in_cache_block_infos != nullptr &&
-          copy_in_cache_block_infos->size() == dp_size_) {
-        batches[i].set_copy_in_cache_block_infos(
-            &(copy_in_cache_block_infos->at(i)));
-      }
-      if (copy_out_cache_block_infos != nullptr &&
-          copy_out_cache_block_infos->size() == dp_size_) {
-        batches[i].set_copy_out_cache_block_infos(
-            &(copy_out_cache_block_infos->at(i)));
-      }
-      if (swap_cache_block_infos != nullptr &&
-          swap_cache_block_infos->size() == dp_size_) {
-        batches[i].set_swap_cache_block_infos(&(swap_cache_block_infos->at(i)));
+      if (swap_block_transfer_infos != nullptr &&
+          swap_block_transfer_infos->size() == dp_size_) {
+        batches[i].set_swap_block_transfer_infos(
+            &(swap_block_transfer_infos->at(i)));
       }
     }
   }
diff --git a/xllm/core/framework/batch/batch_factory.h b/xllm/core/framework/batch/batch_factory.h
index 44106771..bd7d7a08 100644
--- a/xllm/core/framework/batch/batch_factory.h
+++ b/xllm/core/framework/batch/batch_factory.h
@@ -31,14 +31,8 @@ class BatchFactory {
       const std::vector<std::shared_ptr<Request>>& running_requests,
       const std::vector<Sequence*>& running_sequences,
       const std::vector<size_t>& running_sequences_budgets,
-      // for global kv cache copy block from host to device
-      std::vector<std::vector<CacheBlockInfo>>* copy_in_cache_block_infos =
-          nullptr,
-      // for global kv cache copy block from device to host
-      std::vector<std::vector<CacheBlockInfo>>* copy_out_cache_block_infos =
-          nullptr,
       // for beam-search
-      std::vector<std::vector<CacheBlockInfo>>* swap_cache_block_infos =
+      std::vector<std::vector<BlockTransferInfo>>* swap_block_transfer_infos =
           nullptr);
 
  private:
diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp
index 2ff34176..60ddfd13 100644
--- a/xllm/core/framework/batch/batch_input_builder.cpp
+++ b/xllm/core/framework/batch/batch_input_builder.cpp
@@ -37,30 +37,13 @@ limitations under the License.
 
 namespace xllm {
 
-void split_copy_out_blocks(RawForwardInput& raw_forward_input,
-                           std::unordered_set<int32_t>& write_block_ids) {
-  std::vector<CacheBlockInfo> async_copy_out_blocks;
-  std::vector<CacheBlockInfo> sync_copy_out_blocks;
-  for (CacheBlockInfo& content : raw_forward_input.copy_out_blocks) {
-    if (write_block_ids.find(content.device_block_id) !=
-        write_block_ids.end()) {
-      sync_copy_out_blocks.emplace_back(std::move(content));
-    } else {
-      async_copy_out_blocks.emplace_back(std::move(content));
-    }
-  }
-  raw_forward_input.copy_out_blocks = std::move(sync_copy_out_blocks);
-  raw_forward_input.async_copy_out_blocks = std::move(async_copy_out_blocks);
-}
-
 BatchInputBuilder::BatchInputBuilder(
     const std::vector<Sequence*>& sequences,
     const std::vector<uint32_t>& allowed_max_tokens,
     const std::vector<torch::Tensor>& input_embeddings_vec,
     const std::vector<MMData>& mm_data_vec,
-    const std::vector<CacheBlockInfo>* copy_in_cache_block_infos,
-    const std::vector<CacheBlockInfo>* copy_out_cache_block_infos,
-    std::vector<CacheBlockInfo>* swap_cache_block_infos,
+    std::vector<BlockTransferInfo>* swap_block_transfer_infos,
+    const uint64_t batch_id,
    const ModelArgs* args,
     ThreadPool* thread_pool)
     : sequences_(sequences),
@@ -70,9 +53,8 @@ BatchInputBuilder::BatchInputBuilder(
       args_(args),
       thread_pool_(thread_pool),
       num_sequences_(static_cast<int32_t>(sequences.size())),
-      copy_in_cache_block_infos_(copy_in_cache_block_infos),
-      copy_out_cache_block_infos_(copy_out_cache_block_infos),
-      swap_cache_block_infos_(swap_cache_block_infos) {
+      swap_block_transfer_infos_(swap_block_transfer_infos),
+      batch_id_(batch_id) {
   // Reserve space for better performance
   state_.flatten_tokens_vec.reserve(1000);
   state_.flatten_positions_vec.reserve(1000);
@@ -572,11 +554,11 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
     input_params.input_embedding = torch::cat(input_embeddings_vec_);
   }
 
-  if (swap_cache_block_infos_ != nullptr &&
-      swap_cache_block_infos_->size() > 0) {
+  if (swap_block_transfer_infos_ != nullptr &&
+      swap_block_transfer_infos_->size() > 0) {
     input_params.swap_blocks.insert(input_params.swap_blocks.end(),
-                                    swap_cache_block_infos_->begin(),
-                                    swap_cache_block_infos_->end());
+                                    swap_block_transfer_infos_->begin(),
+                                    swap_block_transfer_infos_->end());
   }
 
   if (FLAGS_enable_continuous_kvcache) {
@@ -663,54 +645,40 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
     }
   }
 
-  if (copy_out_cache_block_infos_ != nullptr &&
-      copy_out_cache_block_infos_->size() > 0) {
-    raw_forward_input.copy_out_blocks.insert(
-        raw_forward_input.copy_out_blocks.end(),
-        copy_out_cache_block_infos_->begin(),
-        copy_out_cache_block_infos_->end());
-  }
-  if (copy_in_cache_block_infos_ != nullptr &&
-      copy_in_cache_block_infos_->size() > 0) {
-    raw_forward_input.copy_in_blocks.insert(
-        raw_forward_input.copy_in_blocks.end(),
-        copy_in_cache_block_infos_->begin(),
-        copy_in_cache_block_infos_->end());
-  }
-  split_copy_out_blocks(raw_forward_input, write_block_ids_);
   process_swap_block_infos(raw_forward_input);
+  raw_forward_input.batch_id = batch_id_;
 
   return raw_forward_input;
 }
 
 void BatchInputBuilder::process_swap_block_infos(
     RawForwardInput& raw_forward_input) {
-  if (swap_cache_block_infos_ == nullptr ||
-      swap_cache_block_infos_->size() == 0) {
+  if (swap_block_transfer_infos_ == nullptr ||
+      swap_block_transfer_infos_->size() == 0) {
     return;
   }
 
   if (FLAGS_enable_block_copy_kernel) {
-    auto& swap_blocks = *swap_cache_block_infos_;
+    auto& swap_blocks = *swap_block_transfer_infos_;
     std::sort(swap_blocks.begin(),
               swap_blocks.end(),
-              [](const CacheBlockInfo& a, const CacheBlockInfo& b) {
-                return a.device_block_id < b.device_block_id;
+              [](const BlockTransferInfo& a, const BlockTransferInfo& b) {
+                return a.src_block_id < b.src_block_id;
               });
     if (swap_blocks.size() > 0) {
       std::vector<int32_t> src_indices, dst_indices, cum_sum;
-      int32_t current_src = swap_blocks[0].device_block_id;
+      int32_t current_src = swap_blocks[0].src_block_id;
       src_indices.reserve(swap_blocks.size());
       dst_indices.reserve(swap_blocks.size());
-      src_indices.push_back(swap_blocks[0].device_block_id);
-      dst_indices.push_back(swap_blocks[0].host_block_id);
+      src_indices.push_back(swap_blocks[0].src_block_id);
+      dst_indices.push_back(swap_blocks[0].dst_block_id);
 
       for (size_t i = 1; i < swap_blocks.size(); i++) {
-        dst_indices.push_back(swap_blocks[i].host_block_id);
-        if (swap_blocks[i].device_block_id != current_src) {
-          src_indices.push_back(swap_blocks[i].device_block_id);
+        dst_indices.push_back(swap_blocks[i].dst_block_id);
+        if (swap_blocks[i].src_block_id != current_src) {
+          src_indices.push_back(swap_blocks[i].src_block_id);
           cum_sum.push_back(i);
-          current_src = swap_blocks[i].device_block_id;
+          current_src = swap_blocks[i].src_block_id;
         }
       }
       cum_sum.push_back(swap_blocks.size());
@@ -722,8 +690,8 @@ void BatchInputBuilder::process_swap_block_infos(
     }
   } else {
     raw_forward_input.swap_blocks.insert(raw_forward_input.swap_blocks.end(),
-                                         swap_cache_block_infos_->begin(),
-                                         swap_cache_block_infos_->end());
+                                         swap_block_transfer_infos_->begin(),
+                                         swap_block_transfer_infos_->end());
   }
 }
 } // namespace xllm
diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h
index 9b76bfb1..8774e405 100644
--- a/xllm/core/framework/batch/batch_input_builder.h
+++ b/xllm/core/framework/batch/batch_input_builder.h
@@ -37,12 +37,9 @@ class BatchInputBuilder {
       const std::vector<uint32_t>& allowed_max_tokens,
       const std::vector<torch::Tensor>& input_embeddings_vec,
       const std::vector<MMData>& mm_data_vec,
-      // for global kv cache copy block from host to device
-      const std::vector<CacheBlockInfo>* copy_in_cache_block_infos,
-      // for global kv cache copy block from device to host
-      const std::vector<CacheBlockInfo>* copy_out_cache_block_infos,
       // for beam-search
-      std::vector<CacheBlockInfo>* swap_cache_block_infos,
+      std::vector<BlockTransferInfo>* swap_block_transfer_infos,
+      const uint64_t batch_id,
       const ModelArgs* args,
       ThreadPool* thread_pool = nullptr);
 
@@ -153,12 +150,11 @@ class BatchInputBuilder {
 
   // copy in and out cache contents
   std::unordered_set<int32_t> write_block_ids_;
-  const std::vector<CacheBlockInfo>* copy_in_cache_block_infos_ = nullptr;
-  const std::vector<CacheBlockInfo>* copy_out_cache_block_infos_ = nullptr;
-  std::vector<CacheBlockInfo>* swap_cache_block_infos_ = nullptr;
+  std::vector<BlockTransferInfo>* swap_block_transfer_infos_ = nullptr;
 
   // thread pool for multithreaded processing, not owned
   ThreadPool* thread_pool_ = nullptr;
+  uint64_t batch_id_ = 0x0;
 };
 } // namespace xllm
diff --git a/xllm/core/framework/block/CMakeLists.txt b/xllm/core/framework/block/CMakeLists.txt
index 97963eff..013d5849 100644
--- a/xllm/core/framework/block/CMakeLists.txt
+++ b/xllm/core/framework/block/CMakeLists.txt
@@ -26,6 +26,7 @@ cc_library(
     SMHasherSupport
     torch
 )
+target_link_libraries(block PRIVATE Folly::folly)
 
 if(USE_NPU)
   set(TEST_SRCS
diff --git a/xllm/core/framework/block/block_manager.h b/xllm/core/framework/block/block_manager.h
index 2ef41471..0fef5e4c 100644
--- a/xllm/core/framework/block/block_manager.h
+++ b/xllm/core/framework/block/block_manager.h
@@ -62,6 +62,7 @@ class BlockManager {
 
   virtual void cache(const Slice<int32_t>& token_ids,
                      std::vector<Block>& blocks) = 0;
+  virtual void cache(const std::vector<Block>& blocks) = 0;
 
   // get merged all dp rank KVCacheEvent
   virtual void get_merged_kvcache_event(KvCacheEvent* event) const = 0;
diff --git a/xllm/core/framework/block/block_manager_impl.cpp b/xllm/core/framework/block/block_manager_impl.cpp
index b15d7c88..b9c1372e 100644
--- a/xllm/core/framework/block/block_manager_impl.cpp
+++ b/xllm/core/framework/block/block_manager_impl.cpp
@@ -16,6 +16,8 @@ limitations under the License.
 
 #include "block_manager_impl.h"
 
+#include <unordered_set>
+
 #include "framework/prefix_cache/prefix_cache_factory.h"
 
 namespace xllm {
@@ -67,8 +69,23 @@ void BlockManagerImpl::deallocate(const Slice<Block>& blocks) {
   if (options_.enable_prefix_cache()) {
     for (const auto& block : blocks) {
       // the block is not shared by other sequence
-      if (block.ref_count() <= 2) {
-        num_used_blocks_.fetch_sub(1, std::memory_order_relaxed);
+      if (block.is_valid() && block.ref_count() <= 2) {
+        if (num_used_blocks_ > 0) {
+          num_used_blocks_.fetch_sub(1, std::memory_order_relaxed);
+        } else {
+          LOG(ERROR) << "num_used_blocks_==0 cannot fetch_sub for id:"
+                     << block.id()
+                     << ", total block size: " << num_total_blocks();
+          std::unordered_set<int32_t> block_id_set;
+          block_id_set.insert(block.id());
+          std::string error_msg = "Block already released: ";
+          for (auto& id : free_blocks_) {
+            if (block_id_set.count(id) != 0) {
+              error_msg.append(std::to_string(id)).append(" ");
+            }
+          }
+          LOG(ERROR) << error_msg;
+        }
       }
     }
   } else {
@@ -141,6 +158,14 @@ void BlockManagerImpl::cache(const Slice<int32_t>& token_ids,
   }
 }
 
+void BlockManagerImpl::cache(const std::vector<Block>& blocks) {
+  if (options_.enable_prefix_cache()) {
+    AUTO_COUNTER(prefix_cache_latency_seconds_insert);
+    // Add the kv cache to the prefix cache
+    prefix_cache_->insert(blocks);
+  }
+}
+
 void BlockManagerImpl::get_merged_kvcache_event(KvCacheEvent* event) const {
   auto events = prefix_cache_->get_upload_kvcache_events();
   if (events != nullptr) {
diff --git a/xllm/core/framework/block/block_manager_impl.h b/xllm/core/framework/block/block_manager_impl.h
index 226f5a7c..6a6c9646 100644
--- a/xllm/core/framework/block/block_manager_impl.h
+++ b/xllm/core/framework/block/block_manager_impl.h
@@ -43,6 +43,7 @@ class BlockManagerImpl : public BlockManager {
 
   // cache blocks when enable prefix cache
   void cache(const Slice<int32_t>& token_ids, std::vector<Block>& blocks) override;
+  void cache(const std::vector<Block>& blocks) override;
 
   void get_merged_kvcache_event(KvCacheEvent* event) const override;
 
diff --git a/xllm/core/framework/block/block_manager_pool.cpp b/xllm/core/framework/block/block_manager_pool.cpp
index 0d8ef693..88476d4c 100644
--- a/xllm/core/framework/block/block_manager_pool.cpp
+++ b/xllm/core/framework/block/block_manager_pool.cpp
@@ -54,7 +54,10 @@ BlockManagerPool::BlockManagerPool(const Options& options, int32_t dp_size)
       }
     }
   }
-  reset_copy_content();
+  reset_transfer_infos();
+  offload_block_transfer_infos_.resize(block_managers_.size());
+  released_host_blocks_.resize(block_managers_.size());
+  released_device_blocks_.resize(block_managers_.size());
 }
 
 int32_t BlockManagerPool::get_manager_with_max_free_blocks() const {
@@ -115,38 +118,73 @@ void BlockManagerPool::deallocate(Sequence* sequence) {
   int32_t dp_rank = get_dp_rank(sequence);
   cache(sequence);
   if (!host_block_managers_.empty()) {
-    cache_host(sequence);
-    host_block_managers_[dp_rank]->deallocate(
-        sequence->host_kv_state().kv_blocks());
+    record_offload_blocks(sequence);
   }
   block_managers_[dp_rank]->deallocate(sequence->kv_state().kv_blocks());
   // release the blocks after prefix cache insertion
   sequence->reset();
 }
 
-std::vector<std::vector<CacheBlockInfo>>*
-BlockManagerPool::get_copy_in_cache_block_infos() {
-  return &copy_in_cache_block_infos_;
+std::vector<std::vector<BlockTransferInfo>>*
+BlockManagerPool::get_swap_block_transfer_infos() {
+  return &swap_block_transfer_infos_;
 }
 
-std::vector<std::vector<CacheBlockInfo>>*
-BlockManagerPool::get_copy_out_cache_block_infos() {
-  return &copy_out_cache_block_infos_;
+std::vector<std::vector<BlockTransferInfo>>*
+BlockManagerPool::get_offload_block_transfer_infos() {
+  return &offload_block_transfer_infos_;
 }
 
-std::vector<std::vector<CacheBlockInfo>>*
-BlockManagerPool::get_swap_cache_block_infos() {
-  return &swap_cache_block_infos_;
+std::vector<std::vector<BlockTransferInfo>>*
+BlockManagerPool::get_load_block_transfer_infos() {
+  return &load_block_transfer_infos_;
 }
 
-void BlockManagerPool::reset_copy_content() {
-  copy_in_cache_block_infos_.clear();
-  copy_in_cache_block_infos_.resize(host_block_managers_.size());
-  copy_out_cache_block_infos_.clear();
-  copy_out_cache_block_infos_.resize(host_block_managers_.size());
-  swap_cache_block_infos_.clear();
-  swap_cache_block_infos_.resize(block_managers_.size());
-  evict_host_blocks_.clear();
+void BlockManagerPool::set_offload_callback(
+    std::vector<std::vector<folly::SemiFuture<uint32_t>>>& futures) {
+  DCHECK(futures.size() == block_managers_.size());
+  for (int i = 0; i < futures.size(); i++) {
+    if (futures[i].empty()) {
+      continue;
+    }
+    // TODO(kangmeng): add timeout
+    folly::collectAll(std::move(futures[i]))
+        .via(folly::getGlobalCPUExecutor())
+        .thenValue([host_blocks = std::move(released_host_blocks_[i]),
+                    device_blocks = std::move(released_device_blocks_[i]),
+                    host_block_mgr_ptr = host_block_managers_[i].get(),
+                    device_block_mgr_ptr = block_managers_[i].get()](
+                       std::vector<folly::Try<uint32_t>>&& results) {
+          for (auto&& result : results) {
+            try {
+              if (result.value() != host_blocks.size()) {
+                LOG(FATAL) << "Offload copy fail, expected "
+                           << host_blocks.size() << ", got " << result.value();
+              }
+            } catch (const std::exception& e) {
+              LOG(FATAL) << "Offload copy fail! Exception caught: " << e.what();
+            }
+          }
+          host_block_mgr_ptr->cache(host_blocks);
+          host_block_mgr_ptr->deallocate({host_blocks});
+          device_block_mgr_ptr->deallocate({device_blocks});
+          return 0;
+        });
+  }
+
+  offload_block_transfer_infos_.clear();
+  released_host_blocks_.clear();
+  released_device_blocks_.clear();
+  offload_block_transfer_infos_.resize(block_managers_.size());
+  released_host_blocks_.resize(block_managers_.size());
+  released_device_blocks_.resize(block_managers_.size());
+}
+
+void BlockManagerPool::reset_transfer_infos() {
+  swap_block_transfer_infos_.clear();
+  swap_block_transfer_infos_.resize(block_managers_.size());
+  load_block_transfer_infos_.clear();
+  load_block_transfer_infos_.resize(block_managers_.size());
 }
 
 bool BlockManagerPool::allocate(Sequence* sequence) {
@@ -172,9 +210,9 @@ bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) {
   // first try to allocate shared blocks
   if (sequence->kv_state().num_kv_blocks() == 0) {
     allocate_shared(sequence);
-  }
-  if (sequence->host_kv_state().num_kv_blocks() == 0) {
-    allocate_host_shared(sequence);
+    if (sequence->host_kv_state().num_kv_blocks() == 0) {
+      allocate_host_shared(sequence);
+    }
   }
 
   const size_t num_blocks = sequence->kv_state().num_kv_blocks();
@@ -208,10 +246,11 @@ bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) {
     for (int i = hbm_cache_token_num / options_.block_size();
          i < host_cache_token_num / options_.block_size();
         i++) {
-      copy_in_cache_block_infos_[dp_rank].emplace_back(
-          hbm_blocks[i].id(),
-          host_blocks[i].id(),
-          host_blocks[i].get_immutable_hash_value());
+      load_block_transfer_infos_[dp_rank].emplace_back(
+          BlockTransferInfo(host_blocks[i].id(),
+                            hbm_blocks[i].id(),
+                            host_blocks[i].get_immutable_hash_value(),
+                            TransferType::H2D));
     }
     sequence->kv_state().incr_kv_cache_tokens_num(host_cache_token_num -
                                                   hbm_cache_token_num);
@@ -243,8 +282,8 @@ void BlockManagerPool::process_beam_search(Sequence* sequence, bool need_swap) {
   if (need_swap && sequence->kv_state().need_swap()) {
     int32_t dp_rank = get_dp_rank(sequence);
     auto new_blocks = block_managers_[dp_rank]->allocate(1);
-    swap_cache_block_infos_[dp_rank].emplace_back(src_blocks.back().id(),
-                                                  new_blocks[0].id());
+    swap_block_transfer_infos_[dp_rank].emplace_back(src_blocks.back().id(),
+                                                     new_blocks[0].id());
     sequence->kv_state().process_beam_search(new_blocks);
   } else {
     sequence->kv_state().process_beam_search({});
@@ -303,7 +342,7 @@ void BlockManagerPool::cache(Sequence* sequence) {
   int32_t dp_rank = get_dp_rank(sequence);
   const auto token_ids = sequence->cached_tokens();
   auto* blocks = sequence->kv_state().mutable_kv_blocks();
-  return block_managers_[dp_rank]->cache(token_ids, *blocks);
+  block_managers_[dp_rank]->cache(token_ids, *blocks);
 }
 
 void BlockManagerPool::allocate_host_shared(Sequence* sequence) {
@@ -321,37 +360,53 @@ void BlockManagerPool::allocate_host_shared(Sequence* sequence) {
   }
 }
 
-void BlockManagerPool::cache_host(Sequence* sequence) {
+void BlockManagerPool::record_offload_blocks(Sequence* sequence) {
   DCHECK(sequence != nullptr);
-  int32_t dp_rank = get_dp_rank(sequence);
-  size_t needed_block_num = (sequence->num_tokens() / options_.block_size() -
-                             sequence->host_kv_state().num_kv_blocks());
+  auto* blocks = sequence->kv_state().mutable_kv_blocks();
+  auto* host_blocks = sequence->host_kv_state().mutable_kv_blocks();
 
-  if (needed_block_num > 0) {
-    sequence->host_kv_state().add_kv_blocks(
-        host_block_managers_[dp_rank]->allocate(needed_block_num));
+  if (blocks->size() == 0 || host_blocks->size() >= blocks->size()) {
+    return;
   }
 
-  evict_host_blocks_.emplace_back(
-      std::move(*sequence->host_kv_state().mutable_kv_blocks()));
-  std::vector<Block>* host_blocks_ptr = &evict_host_blocks_.back();
+  int cached_block_num =
+      sequence->host_kv_state().kv_cache_tokens_num() / options_.block_size();
 
-  auto blocks = sequence->kv_state().kv_blocks();
+  int32_t dp_rank = get_dp_rank(sequence);
+
+  if (host_blocks->size() > 0) {
+    host_block_managers_[dp_rank]->cache(sequence->tokens(), *host_blocks);
+  }
 
-  for (int i = sequence->host_kv_state().kv_cache_tokens_num() /
-               options_.block_size();
-       i < host_blocks_ptr->size();
-       i++) {
-    host_blocks_ptr->at(i).set_hash_value(blocks[i].get_immutable_hash_value());
-    copy_out_cache_block_infos_[dp_rank].emplace_back(
-        blocks[i].id(),
-        host_blocks_ptr->at(i).id(),
-        host_blocks_ptr->at(i).get_immutable_hash_value());
+  size_t needed_block_num =
+      sequence->num_tokens() / options_.block_size() - host_blocks->size();
+
+  if (needed_block_num == 0) {
+    return;
   }
 
+  sequence->host_kv_state().add_kv_blocks(
+      host_block_managers_[dp_rank]->allocate(needed_block_num));
+
+  for (int i = cached_block_num; i < host_blocks->size(); i++) {
+    if (blocks->at(i).ref_count() != 2) {
+      continue;
+    }
+
+    host_blocks->at(i).set_hash_value(blocks->at(i).get_immutable_hash_value());
+    released_host_blocks_[dp_rank].emplace_back(std::move(host_blocks->at(i)));
+    released_device_blocks_[dp_rank].emplace_back(std::move(blocks->at(i)));
+    offload_block_transfer_infos_[dp_rank].emplace_back(BlockTransferInfo(
+        released_device_blocks_[dp_rank].back().id(),
+        released_host_blocks_[dp_rank].back().id(),
+        released_host_blocks_[dp_rank].back().get_immutable_hash_value(),
+        TransferType::D2G));
+  }
   host_block_managers_[dp_rank]->cache(
-      sequence->tokens(), *sequence->host_kv_state().mutable_kv_blocks());
+      *sequence->host_kv_state().mutable_kv_blocks());
+  host_block_managers_[dp_rank]->deallocate(
+      sequence->host_kv_state().kv_blocks());
 }
 
 void BlockManagerPool::get_merged_kvcache_event(KvCacheEvent* event) const {
diff --git a/xllm/core/framework/block/block_manager_pool.h b/xllm/core/framework/block/block_manager_pool.h
index 745cad31..ac2f6e9f 100644
--- a/xllm/core/framework/block/block_manager_pool.h
+++ b/xllm/core/framework/block/block_manager_pool.h
@@ -15,6 +15,7 @@ limitations under the License.
 
 #pragma once
 
+#include <folly/futures/Future.h>
 #include
 
 #include "block_manager.h"
@@ -57,13 +58,15 @@ class BlockManagerPool final : public KVCacheManager {
   void allocate_shared(Sequence* sequence) override;
   void cache(Sequence* sequence) override;
 
-  std::vector<std::vector<CacheBlockInfo>>* get_copy_in_cache_block_infos()
+  std::vector<std::vector<BlockTransferInfo>>* get_swap_block_transfer_infos()
       override;
-  std::vector<std::vector<CacheBlockInfo>>* get_copy_out_cache_block_infos()
+  std::vector<std::vector<BlockTransferInfo>>*
+  get_offload_block_transfer_infos() override;
+  std::vector<std::vector<BlockTransferInfo>>* get_load_block_transfer_infos()
       override;
-  std::vector<std::vector<CacheBlockInfo>>* get_swap_cache_block_infos()
-      override;
-  void reset_copy_content() override;
+  void set_offload_callback(
+      std::vector<std::vector<folly::SemiFuture<uint32_t>>>& futures) override;
+  void reset_transfer_infos() override;
 
   void get_merged_kvcache_event(KvCacheEvent* event) const;
   float get_gpu_cache_usage_perc() const;
@@ -83,7 +86,7 @@ class BlockManagerPool final : public KVCacheManager {
   int32_t get_dp_rank(Sequence* sequence) const;
 
   void allocate_host_shared(Sequence* sequence);
-  void cache_host(Sequence* sequence);
+  void record_offload_blocks(Sequence* sequence);
 
   void process_beam_search(Sequence* sequence, bool need_swap = false);
 
@@ -94,11 +97,12 @@ class BlockManagerPool final : public KVCacheManager {
   // the options for the block manager
   Options options_;
 
-  // CacheBlockInfo per step
-  std::vector<std::vector<CacheBlockInfo>> copy_in_cache_block_infos_;
-  std::vector<std::vector<CacheBlockInfo>> copy_out_cache_block_infos_;
-  std::vector<std::vector<CacheBlockInfo>> swap_cache_block_infos_;
-  std::vector<std::vector<Block>> evict_host_blocks_;
+  // BlockTransferInfo per step
+  std::vector<std::vector<BlockTransferInfo>> swap_block_transfer_infos_;
+  std::vector<std::vector<BlockTransferInfo>> load_block_transfer_infos_;
+  std::vector<std::vector<BlockTransferInfo>> offload_block_transfer_infos_;
+  std::vector<std::vector<Block>> released_host_blocks_;
+  std::vector<std::vector<Block>> released_device_blocks_;
 };
 } // namespace xllm
diff --git a/xllm/core/framework/block/concurrent_block_manager_impl.cpp b/xllm/core/framework/block/concurrent_block_manager_impl.cpp
index 17f08080..b80267c4 100644
--- a/xllm/core/framework/block/concurrent_block_manager_impl.cpp
+++ b/xllm/core/framework/block/concurrent_block_manager_impl.cpp
@@ -43,6 +43,11 @@ void ConcurrentBlockManagerImpl::cache(const Slice<int32_t>& token_ids,
   BlockManagerImpl::cache(token_ids, blocks);
 }
 
+void ConcurrentBlockManagerImpl::cache(const std::vector<Block>& blocks) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  BlockManagerImpl::cache(blocks);
+}
+
 size_t ConcurrentBlockManagerImpl::num_blocks_in_prefix_cache() const {
   std::lock_guard<std::mutex> lock(mutex_);
   return BlockManagerImpl::num_blocks_in_prefix_cache();
diff --git a/xllm/core/framework/block/concurrent_block_manager_impl.h b/xllm/core/framework/block/concurrent_block_manager_impl.h
index 8e7f9cdc..68c2804c 100644
--- a/xllm/core/framework/block/concurrent_block_manager_impl.h
+++ b/xllm/core/framework/block/concurrent_block_manager_impl.h
@@ -38,6 +38,7 @@ class ConcurrentBlockManagerImpl : public BlockManagerImpl {
 
   // cache the blocks
   void cache(const Slice<int32_t>& token_ids, std::vector<Block>& blocks) override;
+  void cache(const std::vector<Block>& blocks) override;
 
   // get the number of blocks in the prefix cache
   size_t num_blocks_in_prefix_cache() const override;
diff --git a/xllm/core/framework/block/kv_cache_manager.h b/xllm/core/framework/block/kv_cache_manager.h
index 3b1950f4..e3913df4 100644
--- a/xllm/core/framework/block/kv_cache_manager.h
+++ b/xllm/core/framework/block/kv_cache_manager.h
@@ -42,13 +42,19 @@ class KVCacheManager {
   virtual void allocate_shared(Sequence* sequence) = 0;
   virtual void cache(Sequence* sequence) = 0;
 
-  virtual std::vector<std::vector<CacheBlockInfo>>*
-  get_copy_in_cache_block_infos() = 0;
-  virtual std::vector<std::vector<CacheBlockInfo>>*
-  get_copy_out_cache_block_infos() = 0;
virtual std::vector<std::vector<CacheBlockInfo>>*
-      get_swap_cache_block_infos() = 0;
-  virtual void reset_copy_content() = 0;
+  virtual std::vector<std::vector<BlockTransferInfo>>*
+      get_swap_block_transfer_infos() = 0;
+
+  virtual std::vector<std::vector<BlockTransferInfo>>*
+      get_offload_block_transfer_infos() = 0;
+
+  virtual std::vector<std::vector<BlockTransferInfo>>*
+      get_load_block_transfer_infos() = 0;
+
+  virtual void set_offload_callback(
+      std::vector<std::vector<folly::SemiFuture<uint32_t>>>& futures) = 0;
+
+  virtual void reset_transfer_infos() = 0;

   virtual uint32_t num_blocks() const = 0;
   virtual int32_t block_size() const = 0;
diff --git a/xllm/core/framework/kv_cache/kv_cache_store.cpp b/xllm/core/framework/kv_cache/kv_cache_store.cpp
index 1b430860..576ceff1 100644
--- a/xllm/core/framework/kv_cache/kv_cache_store.cpp
+++ b/xllm/core/framework/kv_cache/kv_cache_store.cpp
@@ -11,79 +11,66 @@
 namespace xllm {

-KVCacheStore::KVCacheStore(const StoreConfig& config,
-                           std::vector<KVCache>* host_kv_caches)
-    : config_(config), host_kv_caches_(host_kv_caches) {
+bool KVCacheStore::init(const StoreConfig& config,
+                        std::vector<KVCache>* host_kv_caches) {
+  CHECK(!is_initialized_) << "KVCacheStore is already initialized.";
+  config_ = config;
+  host_kv_caches_ = host_kv_caches;
+
+  std::optional<std::string> device_names = std::nullopt;
   if (config_.protocol == "rdma") {
-    if (getenv("DEVICE_NAME")) {
-      auto name = getenv("DEVICE_NAME");
-      LOG(INFO) << "device name: " << name;
-      args_ = mooncake::rdma_args(name);
+    if (getenv("DEVICE_NAMES")) {
+      device_names = getenv("DEVICE_NAMES");
+      LOG(INFO) << "device_names: " << device_names.value();
     } else {
-      LOG(WARNING) << "env DEVICE_NAME not exist, set protocol as tcp";
+      LOG(WARNING) << "env DEVICE_NAMES does not exist, falling back to tcp";
       config_.protocol = "tcp";
-      args_ = nullptr;
     }
   }

   auto client_opt = mooncake::Client::Create(config_.localhost_name,
-                                             config_.metadata_connstring,
+                                             config_.metadata_server,
                                              config_.protocol,
-                                             args_,
-                                             config_.master_server_entry);
+                                             device_names,
+                                             config_.master_server_address);

   rep_config_.replica_num = config_.replica_num;
   // rep_config_.preferred_segment = config_.localhost_name;
   if (!client_opt.has_value()) {
-    LOG(FATAL) << "mooncake::Client::Create fail!";
-    return;
+    LOG(FATAL) << "mooncake::Client::Create fail! Failed to create client with "
+                  "host_name: "
+               << config_.localhost_name;
   }
   client_ptr_ = client_opt.value();

-  auto key_tensor_one_layer = host_kv_caches_->at(0).get_k_cache();
-  auto value_tensor_one_layer = host_kv_caches_->at(0).get_v_cache();
+  auto k_tensor_one_block = host_kv_caches_->at(0).get_k_cache();
+  auto v_tensor_one_block = host_kv_caches_->at(0).get_v_cache();

-  key_cache_size_per_layer_ =
-      key_tensor_one_layer[0].numel() * key_tensor_one_layer[0].element_size();
-  value_cache_size_per_layer_ = value_tensor_one_layer[0].numel() *
-                                value_tensor_one_layer[0].element_size();
+  k_cache_size_per_block_ =
+      k_tensor_one_block.numel() * k_tensor_one_block.element_size();
+  v_cache_size_per_block_ =
+      v_tensor_one_block.numel() * v_tensor_one_block.element_size();

-  auto key_cache_host_size =
-      key_tensor_one_layer.numel() * key_tensor_one_layer.element_size();
-  auto value_cache_host_size =
-      value_tensor_one_layer.numel() * value_tensor_one_layer.element_size();
-
-  LOG(INFO) << "key_cache_size_per_layer: " << key_cache_size_per_layer_;
-  LOG(INFO) << "value_cache_size_per_layer: " << value_cache_size_per_layer_;
+  LOG(INFO) << "k_cache_size_per_block: " << k_cache_size_per_block_;
+  LOG(INFO) << "v_cache_size_per_block: " << v_cache_size_per_block_;

   if (config_.protocol == "rdma") {
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
-      void* key_cache = static_cast<char*>(
-          host_kv_caches_->at(layer).get_k_cache().data_ptr());
-
-      auto register_k_result = client_ptr_->RegisterLocalMemory(
-          key_cache, key_cache_host_size, "cpu:0", false, false);
-
-      if (!register_k_result.has_value()) {
-        LOG(ERROR) << "Failed to register local memory for key cache: "
-                   << toString(register_k_result.error());
-        return;
-      }
-
-      void* value_cache = static_cast<char*>(
-          host_kv_caches_->at(layer).get_v_cache().data_ptr());
-
-      auto register_v_result = client_ptr_->RegisterLocalMemory(
-          value_cache, value_cache_host_size, "cpu:0", false, false);
-
-      if (!register_v_result.has_value()) {
-        LOG(ERROR) << "Failed to register local memory for value cache: "
-                   << toString(register_v_result.error());
-        return;
+    if (config_.total_size > 0 && config_.tensor_data != nullptr) {
+      auto result = client_ptr_->RegisterLocalMemory(
+          config_.tensor_data, config_.total_size, "cpu:0", false, false);
+      if (!result.has_value()) {
+        LOG(ERROR) << "Failed to register local memory: "
+                   << toString(result.error());
+        return false;
       }
+    } else {
+      LOG(FATAL) << "rdma must RegisterLocalMemory, but got register size: "
+                 << config_.total_size
+                 << ", and data ptr: " << uint64_t(config_.tensor_data);
     }
   }
+  is_initialized_ = true;
+  return true;
 }

 KVCacheStore::~KVCacheStore() {
@@ -92,14 +79,17 @@ KVCacheStore::~KVCacheStore() {
   }
 }

-uint64_t KVCacheStore::batch_put(
-    const std::vector<CacheBlockInfo>& cache_block_info) {
+uint32_t KVCacheStore::batch_put(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!is_initialized_) {
+    return 0;
+  }
   std::vector<std::string> str_keys;
   std::vector<std::vector<mooncake::Slice>> slices;
-  str_keys.reserve(cache_block_info.size());
-  slices.reserve(cache_block_info.size());
-  for (auto block_info : cache_block_info) {
+  str_keys.reserve(block_transfer_info.size());
+  slices.reserve(block_transfer_info.size());
+  for (const auto& block_info : block_transfer_info) {
     std::string str_key(reinterpret_cast<const char*>(block_info.hash_key),
                         MURMUR_HASH3_VALUE_LEN);
@@ -112,50 +102,42 @@ uint64_t KVCacheStore::batch_put(
     str_keys.emplace_back(str_key);

-    std::vector<mooncake::Slice> slice;
-    slice.reserve(host_kv_caches_->size() * 2);
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
-      void* key_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_k_cache().data_ptr()) +
-          block_info.host_block_id * key_cache_size_per_layer_;
-      slice.emplace_back(mooncake::Slice{key_cache, key_cache_size_per_layer_});
-
-      void* value_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_v_cache().data_ptr()) +
-          block_info.host_block_id * value_cache_size_per_layer_;
-      slice.emplace_back(
-          mooncake::Slice{value_cache, value_cache_size_per_layer_});
-    }
-    slices.emplace_back(std::move(slice));
+    void* k_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    void* v_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+
+    slices.emplace_back(std::vector<mooncake::Slice>{
+        mooncake::Slice{k_cache, k_cache_size_per_block_},
+        mooncake::Slice{v_cache, v_cache_size_per_block_}});
   }

   if (str_keys.size() == 0) {
-    return cache_block_info.size();
+    return block_transfer_info.size();
   }

-  uint64_t success_cnt = str_keys.size();
+  uint64_t success_cnt = block_transfer_info.size() - str_keys.size();
   auto results = client_ptr_->BatchPut(str_keys, slices, rep_config_);
   for (int i = 0; i < str_keys.size(); i++) {
     if (!results[i].has_value()) {
-      success_cnt = i;
-      // LOG(ERROR) << "success_cnt: " << success_cnt
-      //            << ", failed to BatchPut: " << toString(results[i].error());
       break;
     }
+    success_cnt++;
   }
   return success_cnt;
 }

-uint64_t KVCacheStore::batch_get(
-    const std::vector<CacheBlockInfo>& cache_block_info) {
+uint32_t KVCacheStore::batch_get(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!is_initialized_) {
+    return 0;
+  }
   std::unordered_map<std::string, std::vector<mooncake::Slice>> slices;
   std::vector<std::string> str_keys;
-  str_keys.reserve(cache_block_info.size());
-  for (auto block_info : cache_block_info) {
+  str_keys.reserve(block_transfer_info.size());
+  for (const auto& block_info : block_transfer_info) {
     std::string str_key(reinterpret_cast<const char*>(block_info.hash_key),
                         MURMUR_HASH3_VALUE_LEN);
@@ -167,47 +149,38 @@ uint64_t KVCacheStore::batch_get(
     str_keys.emplace_back(str_key);

-    slices.insert(std::make_pair(str_key, std::vector<mooncake::Slice>()));
-
-    slices[str_key].reserve(host_kv_caches_->size() * 2);
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
-      void* key_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_k_cache().data_ptr()) +
-          block_info.host_block_id * key_cache_size_per_layer_;
-      slices[str_key].emplace_back(
-          mooncake::Slice{key_cache, key_cache_size_per_layer_});
-
-      void* value_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_v_cache().data_ptr()) +
-          block_info.host_block_id * value_cache_size_per_layer_;
-      slices[str_key].emplace_back(
-          mooncake::Slice{value_cache, value_cache_size_per_layer_});
-    }
+    void* k_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    void* v_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+
+    slices.insert(
+        std::make_pair(str_key,
+                       std::vector<mooncake::Slice>{
+                           mooncake::Slice{k_cache, k_cache_size_per_block_},
+                           mooncake::Slice{v_cache, v_cache_size_per_block_}}));
   }

   if (str_keys.size() == 0) {
     return 0;
   }

-  uint64_t success_cnt = str_keys.size();
+  uint64_t success_cnt = 0;
   auto results = client_ptr_->BatchGet(str_keys, slices);
   for (int i = 0; i < str_keys.size(); i++) {
     if (!results[i].has_value()) {
-      success_cnt = i;
-      // LOG(ERROR) << "success_cnt: " << success_cnt
-      //            << ", failed to BatchGet: " << toString(results[i].error());
       break;
     }
+    success_cnt++;
   }
   return success_cnt;
 }

-uint64_t KVCacheStore::batch_remove(
-    const std::vector<CacheBlockInfo>& cache_block_info) {
-  uint64_t success_cnt = 0;
-  for (auto block_info : cache_block_info) {
+uint32_t KVCacheStore::batch_remove(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  CHECK(is_initialized_) << "KVCacheStore is not initialized.";
+  uint32_t success_cnt = 0;
+  for (const auto& block_info : block_transfer_info) {
     std::string str_key(reinterpret_cast<const char*>(block_info.hash_key),
                         MURMUR_HASH3_VALUE_LEN);
     str_key.append(std::to_string(config_.tp_rank));
@@ -221,4 +194,19 @@ uint64_t KVCacheStore::batch_remove(
   return success_cnt;
 }

+uint32_t KVCacheStore::batch_exist(std::vector<std::string>&& keys) {
+  if (!is_initialized_) {
+    return 0;
+  }
+  auto exist_vec = client_ptr_->BatchIsExist(std::move(keys));
+  uint32_t ret = 0;
+  for (auto exist : exist_vec) {
+    if (!exist.has_value() || !exist.value()) {
+      break;
+    }
+    ret++;
+  }
+  return ret;
+}
+
 }  // namespace xllm
diff --git a/xllm/core/framework/kv_cache/kv_cache_store.h b/xllm/core/framework/kv_cache/kv_cache_store.h
index 7877f895..8cf7b9cf 100644
--- a/xllm/core/framework/kv_cache/kv_cache_store.h
+++ b/xllm/core/framework/kv_cache/kv_cache_store.h
@@ -8,40 +8,71 @@
 #include "common/macros.h"
 #include "framework/model/model_input_params.h"
 #include "kv_cache.h"
+#include "util/slice.h"

 namespace xllm {

 struct StoreConfig {
   std::string localhost_name = "127.0.0.1";
   std::string protocol = "tcp";
-  std::string metadata_connstring = "";
-  std::string master_server_entry = "";
+  std::string metadata_server = "";
+  std::string master_server_address = "";
   int replica_num = 1;
   uint32_t tp_rank = 0;
+  size_t total_size = 0;
+  void* tensor_data = nullptr;
 };

 class KVCacheStore {
  public:
-  KVCacheStore(const StoreConfig& config,
-               std::vector<KVCache>* host_kv_caches);
   ~KVCacheStore();

-  uint64_t batch_put(const std::vector<CacheBlockInfo>& cache_block_info);
+  bool init(const StoreConfig& config,
+            std::vector<KVCache>* host_kv_caches);

-  uint64_t batch_get(const std::vector<CacheBlockInfo>& cache_block_info);
+  uint32_t batch_put(
+      const std::vector<BlockTransferInfo>& block_transfer_info) {
+    return batch_put({block_transfer_info});
+  }

-  uint64_t batch_remove(const std::vector<CacheBlockInfo>& cache_block_info);
+  uint32_t batch_get(
+      const std::vector<BlockTransferInfo>& block_transfer_info) {
+    return batch_get({block_transfer_info});
+  }
+
+  uint32_t batch_remove(
+      const std::vector<BlockTransferInfo>& block_transfer_info) {
+    return batch_remove({block_transfer_info});
+  }
+
+  uint32_t batch_put(Slice<BlockTransferInfo>& block_transfer_info);
+
+  uint32_t batch_get(Slice<BlockTransferInfo>& block_transfer_info);
+
+  uint32_t batch_remove(Slice<BlockTransferInfo>& block_transfer_info);
+
+  uint32_t batch_exist(std::vector<std::string>&& keys);
+
+  static KVCacheStore& get_instance() {
+    static KVCacheStore kvcache_store;
+    return kvcache_store;
+  }
+
+ private:
+  KVCacheStore() = default;
+  KVCacheStore(const KVCacheStore&) = delete;
+  KVCacheStore& operator=(const KVCacheStore&) = delete;

  private:
+  bool is_initialized_ = false;
+
   StoreConfig config_;
   mooncake::ReplicateConfig rep_config_;
-  void** args_ = nullptr;
-
   std::vector<KVCache>* host_kv_caches_;
-  uint64_t key_cache_size_per_layer_;
-  uint64_t value_cache_size_per_layer_;
+  uint64_t k_cache_size_per_block_;
+  uint64_t v_cache_size_per_block_;
   std::shared_ptr<mooncake::Client> client_ptr_;
 };
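For orientation, here is a minimal usage sketch of the reworked KVCacheStore singleton above. It is not part of the patch; `total_bytes`, `base_ptr`, `host_kv_caches` (a `std::vector<KVCache>`) and `hash_key` are placeholder assumptions, and the real call sites live in WorkerImpl::allocate_host_kv_cache and the offload path:

    // Hedged sketch: initialize the process-wide store once, then put whole
    // host blocks keyed by their murmur3 hash.
    StoreConfig config;
    config.protocol = "rdma";  // falls back to tcp when DEVICE_NAMES is unset
    config.metadata_server = "etcd://127.0.0.1:2379";  // placeholder endpoint
    config.master_server_address = "127.0.0.1:50051";  // placeholder endpoint
    config.total_size = total_bytes;  // one contiguous pinned host region...
    config.tensor_data = base_ptr;    // ...registered via RegisterLocalMemory
    if (!KVCacheStore::get_instance().init(config, &host_kv_caches)) {
      LOG(ERROR) << "KVCacheStore init failed";
    }

    // Offload one host block (dst_block_id indexes host_kv_caches).
    std::vector<BlockTransferInfo> infos;
    infos.emplace_back(/*src_block_id=*/3, /*dst_block_id=*/7, hash_key,
                       TransferType::D2G);  // non-owning hash_key ctor
    uint32_t success_cnt = KVCacheStore::get_instance().batch_put(infos);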
#include "util/tensor_helper.h" namespace xllm { -struct CacheBlockInfo { - int32_t device_block_id = 0; - int32_t host_block_id = 0; - uint8_t* hash_key = nullptr; - CacheBlockInfo() {} +enum class TransferType : uint8_t { G2H = 0, H2D = 1, D2G = 2 }; + +struct BlockTransferInfo { + int32_t src_block_id = -1; + int32_t dst_block_id = -1; + uint8_t* hash_key = nullptr; + TransferType transfer_type; + uint32_t hash_key_len = -1; - CacheBlockInfo(int32_t device_block_id, int32_t host_block_id) { - this->device_block_id = device_block_id; - this->host_block_id = host_block_id; + BlockTransferInfo(int32_t src_block_id, int32_t dst_block_id) { + this->src_block_id = src_block_id; + this->dst_block_id = dst_block_id; } - CacheBlockInfo(int32_t device_block_id, - int32_t host_block_id, - const uint8_t* hash_key) { - this->device_block_id = device_block_id; - this->host_block_id = host_block_id; + BlockTransferInfo(int32_t src_block_id, + int32_t dst_block_id, + const uint8_t* hash_key, + TransferType transfer_type) { + this->src_block_id = src_block_id; + this->dst_block_id = dst_block_id; this->hash_key = const_cast(hash_key); + this->transfer_type = transfer_type; + } + + BlockTransferInfo(int32_t src_block_id, + int32_t dst_block_id, + const uint8_t* hash_key, + uint32_t hash_key_len, + TransferType transfer_type) { + this->src_block_id = src_block_id; + this->dst_block_id = dst_block_id; + this->hash_key = new uint8_t[hash_key_len]; + memcpy(this->hash_key, hash_key, hash_key_len); + this->transfer_type = transfer_type; + } + + ~BlockTransferInfo() { + if (hash_key_len != -1 && hash_key != nullptr) { + delete[] hash_key; + } + } + + std::string to_string() const { + std::string rt = ", has_key:"; + for (int i = 0; i < 16; i++) { + rt += std::to_string(int64_t(*(hash_key + i))) + " "; + } + return std::to_string(src_block_id) + "->" + std::to_string(dst_block_id) + + ", " + std::to_string(uint32_t(transfer_type)) + rt; } }; @@ -81,9 +113,6 @@ struct ModelInputParams { #endif params.expert_load_data = expert_load_data; - params.async_copy_out_blocks = std::move(async_copy_out_blocks); - params.copy_out_blocks = std::move(copy_out_blocks); - params.copy_in_blocks = std::move(copy_in_blocks); params.swap_blocks = std::move(swap_blocks); params.src_block_indices = safe_to(src_block_indices, device, true); @@ -97,6 +126,8 @@ struct ModelInputParams { // Copy graph_buffer to device params.graph_buffer = safe_to(graph_buffer, device, true); + params.batch_id = batch_id; + return params; } @@ -172,11 +203,8 @@ struct ModelInputParams { // extra token ids for each sequence, and -1 for last chunk std::vector extra_token_ids; - // copy in / copy out - std::vector async_copy_out_blocks; - std::vector copy_out_blocks; - std::vector copy_in_blocks; - std::vector swap_blocks; + // swap + std::vector swap_blocks; // block copy kernel torch::Tensor src_block_indices; @@ -185,6 +213,8 @@ struct ModelInputParams { #if defined(USE_NPU) std::shared_ptr layer_synchronizer = nullptr; + std::shared_ptr layer_wise_load_synchronizer = + nullptr; #endif DpEpPaddingData dp_ep_padding_data; @@ -201,6 +231,8 @@ struct ModelInputParams { // Graph execution buffer for temporary tensor storage // Used by ACL Graph Executor to avoid repeated memory allocation torch::Tensor graph_buffer; + + uint64_t batch_id; }; } // namespace xllm diff --git a/xllm/core/framework/prefix_cache/prefix_cache.cpp b/xllm/core/framework/prefix_cache/prefix_cache.cpp index fac8ccfb..baadc3f3 100644 --- 
diff --git a/xllm/core/framework/prefix_cache/prefix_cache.cpp b/xllm/core/framework/prefix_cache/prefix_cache.cpp
index fac8ccfb..baadc3f3 100644
--- a/xllm/core/framework/prefix_cache/prefix_cache.cpp
+++ b/xllm/core/framework/prefix_cache/prefix_cache.cpp
@@ -124,6 +124,11 @@ size_t PrefixCache::insert(const Slice<int32_t>& token_ids,
   return insert(token_ids, blocks, &insert_keys);
 }

+size_t PrefixCache::insert(const std::vector<Block>& blocks) {
+  std::vector<Murmur3Key> insert_keys;
+  return insert(blocks, &insert_keys);
+}
+
 size_t PrefixCache::evict(size_t n_blocks) {
   std::vector<Murmur3Key> evict_keys;
   return evict(n_blocks, &evict_keys);
@@ -192,6 +197,49 @@ size_t PrefixCache::insert(const Slice<int32_t>& token_ids,
   return n_tokens;
 }

+size_t PrefixCache::insert(const std::vector<Block>& blocks,
+                           std::vector<Murmur3Key>* insert_keys) {
+  const int64_t now = absl::ToUnixMicros(absl::Now());
+  DNodeList node_list;
+  Murmur3Key token_hash_key;
+
+  insert_keys->reserve(blocks.size());
+  for (size_t i = 0; i < blocks.size(); i++) {
+    if (!blocks[i].is_valid()) {
+      continue;
+    }
+    token_hash_key.set(blocks[i].get_immutable_hash_value());
+
+    auto iter = cached_blocks_.find(token_hash_key);
+    if (iter != cached_blocks_.end()) {
+      iter->second->last_access_time = now;
+
+      lru_lst_.remove_node(iter->second);
+      node_list.push_front(iter->second);
+    } else {
+      Node* new_node = new Node();
+
+      new_node->block = blocks[i];
+      new_node->last_access_time = now;
+
+      node_list.push_front(new_node);
+
+      cached_blocks_.emplace(std::make_pair(token_hash_key, new_node));
+
+      num_blocks_++;
+
+      insert_keys->emplace_back(token_hash_key.data);
+    }
+  }
+
+  while (!node_list.is_empty()) {
+    Node* node = node_list.pop_front();
+    lru_lst_.push_back(node);
+  }
+
+  return blocks.size() * block_size_;
+}
+
 size_t PrefixCache::evict(size_t n_blocks,
                           std::vector<Murmur3Key>* evict_keys) {
   if (num_blocks_ == 0 || lru_lst_.is_empty()) {
diff --git a/xllm/core/framework/prefix_cache/prefix_cache.h b/xllm/core/framework/prefix_cache/prefix_cache.h
index fc778419..48a4e9ba 100644
--- a/xllm/core/framework/prefix_cache/prefix_cache.h
+++ b/xllm/core/framework/prefix_cache/prefix_cache.h
@@ -68,6 +68,7 @@ class PrefixCache {
   virtual size_t insert(const Slice<int32_t>& token_ids,
                         std::vector<Block>& blocks);
+  virtual size_t insert(const std::vector<Block>& blocks);

   // evict blocks held by the prefix cache
   // return the actual number of evicted blocks
@@ -100,6 +101,8 @@ class PrefixCache {
   size_t insert(const Slice<int32_t>& token_ids,
                 std::vector<Block>& blocks,
                 std::vector<Murmur3Key>* insert_keys);
+  size_t insert(const std::vector<Block>& blocks,
+                std::vector<Murmur3Key>* insert_keys);
   size_t evict(size_t n_blocks, std::vector<Murmur3Key>* evict_keys);

   struct Node {
diff --git a/xllm/core/framework/request/request.h b/xllm/core/framework/request/request.h
index 168bf26c..27cb1b2f 100644
--- a/xllm/core/framework/request/request.h
+++ b/xllm/core/framework/request/request.h
@@ -73,7 +73,12 @@ class Request : public RequestBase {
   size_t total_num_blocks();

-  void set_preempted() { state_.preempted = true; }
+  void set_preempted() {
+    state_.preempted = true;
+    for (auto& seq : sequences_group_->sequences()) {
+      seq->preempted();
+    }
+  }

   bool preempted() const { return state_.preempted; }
diff --git a/xllm/core/framework/request/sequence.cpp b/xllm/core/framework/request/sequence.cpp
index 9a7d2ce9..c9fd646b 100644
--- a/xllm/core/framework/request/sequence.cpp
+++ b/xllm/core/framework/request/sequence.cpp
@@ -383,6 +383,7 @@ void Sequence::reset() {
   kv_state_.reset();
   host_kv_state_.reset();
   volatile_num_prompt_tokens_ = num_tokens_;
+  preempted_ = false;
 }

 void Sequence::add_shared_kv_blocks(std::vector<Block>&& blocks) {
@@ -456,4 +457,21 @@ Slice<int32_t> Sequence::get_generated_tokens() const {
   return {tokens_.data(), 0};
 }

+void
Sequence::update_prefetch_result() { + if (prefetch_results_.empty()) { + return; + } + + termination_flag_.store(true, std::memory_order_release); + uint32_t success_cnt = host_kv_state_.kv_blocks().size(); + for (auto& cnt : prefetch_results_) { + success_cnt = std::min(success_cnt, cnt->load()); + } + if (success_cnt > 0) { + host_kv_state_.incr_kv_cache_tokens_num( + success_cnt * host_kv_state_.kv_blocks()[0].size()); + } + prefetch_results_.clear(); +} + } // namespace xllm diff --git a/xllm/core/framework/request/sequence.h b/xllm/core/framework/request/sequence.h index 846c037b..4eaf8399 100644 --- a/xllm/core/framework/request/sequence.h +++ b/xllm/core/framework/request/sequence.h @@ -192,6 +192,9 @@ class Sequence final { void close() { closed_ = true; } bool is_closed() const { return closed_; } + void preempted() { preempted_ = true; } + bool is_preempted() const { return preempted_; } + // time between two tokens int64_t tbt(const absl::Time& now); // set sequence ttft @@ -230,26 +233,12 @@ class Sequence final { const Tokenizer& tokenizer, std::optional>& out_logprobs); - void set_async_result(std::vector>&& futures) { - futures_ = std::move(futures); + const std::atomic& get_termination_flag() { return termination_flag_; } + std::vector>>* get_prefetch_results() { + return &prefetch_results_; } - void sync_result() { - if (futures_.has_value()) { - auto success_cnt = host_kv_state_.num_kv_blocks(); - for (auto& future : futures_.value()) { - if (future.isReady()) { - success_cnt = std::min(success_cnt, size_t(future.value())); - } else { - return; - } - } - if (success_cnt > 0) { - host_kv_state_.incr_kv_cache_tokens_num( - success_cnt * host_kv_state_.kv_blocks()[0].size()); - } - } - } + void update_prefetch_result(); void reset(); @@ -335,6 +324,9 @@ class Sequence final { // is the sequence closed. bool closed_ = false; + // is the sequence preempted. 
+ bool preempted_ = false; + // dp_rank int32_t dp_rank_ = -1; @@ -355,7 +347,8 @@ class Sequence final { std::queue is_pre_scheduled_step_prefill_; // kvcache store copy async result - std::optional>> futures_; + std::atomic termination_flag_{false}; + std::vector>> prefetch_results_; }; } // namespace xllm diff --git a/xllm/core/framework/xtensor/xtensor_manager_pool.h b/xllm/core/framework/xtensor/xtensor_manager_pool.h index fe98efab..c706eee8 100644 --- a/xllm/core/framework/xtensor/xtensor_manager_pool.h +++ b/xllm/core/framework/xtensor/xtensor_manager_pool.h @@ -55,26 +55,34 @@ class XTensorManagerPool final : public KVCacheManager { LOG(FATAL) << "allocate_shared is not implemented for page manager pool"; } - std::vector>* get_copy_in_cache_block_infos() + std::vector>* + get_offload_block_transfer_infos() override { + LOG(FATAL) + << "get_offload_block_transfer_infos is not implemented for page " + "manager pool"; + } + + std::vector>* get_load_block_transfer_infos() override { - LOG(FATAL) << "get_copy_in_cache_block_infos is not implemented for page " + LOG(FATAL) << "get_load_block_transfer_infos is not implemented for page " "manager pool"; } - std::vector>* get_copy_out_cache_block_infos() + std::vector>* get_swap_block_transfer_infos() override { - LOG(FATAL) << "get_copy_out_cache_block_infos is not implemented for page " + LOG(FATAL) << "get_swap_block_transfer_infos is not implemented for page " "manager pool"; } - std::vector>* get_swap_cache_block_infos() - override { - LOG(FATAL) << "get_swap_cache_block_infos is not implemented for page " + void set_offload_callback( + std::vector>>& futures) override { + LOG(FATAL) << "set_offload_callback is not implemented for page " "manager pool"; } - void reset_copy_content() override { - LOG(FATAL) << "reset_copy_content is not implemented for page manager pool"; + void reset_transfer_infos() override { + LOG(FATAL) + << "reset_transfer_infos is not implemented for page manager pool"; } uint32_t num_blocks() const override { diff --git a/xllm/core/platform/device.cpp b/xllm/core/platform/device.cpp index 6c3763c6..e99cd38d 100644 --- a/xllm/core/platform/device.cpp +++ b/xllm/core/platform/device.cpp @@ -96,8 +96,8 @@ int Device::synchronize_default_stream() { #endif } -std::unique_ptr Device::get_stream_from_pool() { - return std::make_unique(); +std::unique_ptr Device::get_stream_from_pool(const int32_t timeout) { + return std::make_unique(timeout); } } // namespace xllm diff --git a/xllm/core/platform/device.h b/xllm/core/platform/device.h index 65c5b5a6..3e24dec9 100644 --- a/xllm/core/platform/device.h +++ b/xllm/core/platform/device.h @@ -44,7 +44,7 @@ class Device { int64_t free_memory(); int synchronize_default_stream(); - std::unique_ptr get_stream_from_pool(); + std::unique_ptr get_stream_from_pool(const int32_t timeout = -1); private: struct DeviceMem { diff --git a/xllm/core/platform/npu/npu_layer_synchronizer.cpp b/xllm/core/platform/npu/npu_layer_synchronizer.cpp index 017b5315..459cfd2a 100644 --- a/xllm/core/platform/npu/npu_layer_synchronizer.cpp +++ b/xllm/core/platform/npu/npu_layer_synchronizer.cpp @@ -19,8 +19,11 @@ limitations under the License. 
namespace xllm { -NPULayerSynchronizerImpl::NPULayerSynchronizerImpl(const int64_t num_layers) - : events_(num_layers, nullptr), event_record_flags_(num_layers) { +NPULayerSynchronizerImpl::NPULayerSynchronizerImpl(const int64_t num_layers, + const int32_t timeout) + : events_(num_layers, nullptr), + event_record_flags_(num_layers), + timeout_(timeout) { uint32_t flags = ACL_EVENT_SYNC; for (int64_t i = 0; i < num_layers; ++i) { auto ret = aclrtCreateEventWithFlag(&events_[i], flags); @@ -45,9 +48,9 @@ std::atomic* NPULayerSynchronizerImpl::get_event_flag( bool NPULayerSynchronizerImpl::synchronize_layer(const int64_t layer_index) { while (!event_record_flags_[layer_index].load(std::memory_order_acquire)); - auto ret = aclrtSynchronizeEvent(events_[layer_index]); + auto ret = aclrtSynchronizeEventWithTimeout(events_[layer_index], timeout_); if (ret != ACL_SUCCESS) { - LOG(ERROR) << "Synchronize event failed."; + LOG(ERROR) << "Synchronize event failed: " << ret; return false; } return true; diff --git a/xllm/core/platform/npu/npu_layer_synchronizer.h b/xllm/core/platform/npu/npu_layer_synchronizer.h index 83b9af2b..6bf957b1 100644 --- a/xllm/core/platform/npu/npu_layer_synchronizer.h +++ b/xllm/core/platform/npu/npu_layer_synchronizer.h @@ -24,7 +24,8 @@ namespace xllm { class NPULayerSynchronizerImpl { public: - NPULayerSynchronizerImpl(const int64_t num_layers); + NPULayerSynchronizerImpl(const int64_t num_layers, + const int32_t timeout = -1); virtual ~NPULayerSynchronizerImpl(); aclrtEvent* get_event(const int64_t layer_index); @@ -34,6 +35,7 @@ class NPULayerSynchronizerImpl { private: std::vector events_; std::vector> event_record_flags_; + const int32_t timeout_; }; } // namespace xllm diff --git a/xllm/core/platform/stream.cpp b/xllm/core/platform/stream.cpp index 5cb15b48..a1ca0e74 100644 --- a/xllm/core/platform/stream.cpp +++ b/xllm/core/platform/stream.cpp @@ -18,14 +18,16 @@ limitations under the License. 
namespace xllm { #if defined(USE_NPU) -Stream::Stream() : stream_(c10_npu::getNPUStreamFromPool()) {} +Stream::Stream(const int32_t timeout) + : stream_(c10_npu::getNPUStreamFromPool()), timeout_(timeout) {} #elif defined(USE_MLU) -Stream::Stream() : stream_(torch_mlu::getStreamFromPool()) {} +Stream::Stream(const int32_t timeout) + : stream_(torch_mlu::getStreamFromPool()), timeout_(timeout) {} #endif int Stream::synchronize() const { #if defined(USE_NPU) - return aclrtSynchronizeStream(stream_.stream()); + return aclrtSynchronizeStreamWithTimeout(stream_.stream(), timeout_); #elif defined(USE_MLU) stream_.unwrap().synchronize(); return 0; diff --git a/xllm/core/platform/stream.h b/xllm/core/platform/stream.h index 7cb65913..ae3722c7 100644 --- a/xllm/core/platform/stream.h +++ b/xllm/core/platform/stream.h @@ -34,7 +34,7 @@ namespace xllm { class Stream { public: - Stream(); + Stream(const int32_t timeout = -1); ~Stream() = default; Stream(const Stream&) = delete; @@ -44,6 +44,11 @@ class Stream { int synchronize() const; c10::StreamGuard set_stream_guard() const; +#if defined(USE_NPU) + c10_npu::NPUStream* get_stream() { return &stream_; } +#elif defined(USE_MLU) + torch_mlu::MLUStream* get_stream() { return &stream_; } +#endif private: #if defined(USE_NPU) @@ -51,6 +56,7 @@ class Stream { #elif defined(USE_MLU) torch_mlu::MLUStream stream_; #endif + const int32_t timeout_; }; } // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/engine.h b/xllm/core/runtime/engine.h index fd860050..ffc7bcaa 100644 --- a/xllm/core/runtime/engine.h +++ b/xllm/core/runtime/engine.h @@ -81,11 +81,25 @@ class Engine { LOG(FATAL) << " pull_kv_blocks is notimplemented!"; }; - virtual std::vector> - load_kv_blocks_from_store_async( + virtual std::vector> transfer_kv_blocks( const uint32_t dp_rank, - const std::vector& cache_block_info) { - LOG(FATAL) << " load_kv_blocks_from_store is not implemented!"; + const std::vector& block_transfer_info) { + LOG(FATAL) << " transfer_kv_blocks is not implemented!"; + }; + + virtual void transfer_kv_blocks( + const uint32_t dp_rank, + const uint64_t batch_id, + const std::vector& block_transfer_info) { + LOG(FATAL) << " transfer_kv_blocks is not implemented!"; + }; + + virtual void prefetch_from_storage( + const uint32_t dp_rank, + const std::atomic& flag, + const std::vector& block_transfer_info, + std::vector>>* prefetch_results) { + LOG(FATAL) << " prefetch_from_storage is not implemented!"; }; virtual void get_device_info(std::vector& device_ips, diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index dd4a3d8f..0c3723d7 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -166,11 +166,9 @@ struct RawForwardInput { uint32_t prefill_seq_len; // embedding ids of each sequence std::vector embedding_ids; - // copy in / copy out - std::vector async_copy_out_blocks; - std::vector copy_out_blocks; - std::vector copy_in_blocks; - std::vector swap_blocks; + // swap + std::vector swap_blocks; + uint64_t batch_id; // block copy kernel std::vector src_block_indices; std::vector dst_block_indices; diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index 27b7161d..067e755c 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -45,9 +45,8 @@ constexpr size_t sampling_param_fixed_size() { + type_size; // beam_width } -constexpr size_t 
cache_block_info_fixed_size() {
-  return type_size * 2 +
-         16;  // device_block_id + host_block_id + hash_key
+constexpr size_t swap_block_info_fixed_size() {
+  return type_size * 2;  // src_block_id + dst_block_id
 }

 INLINE size_t get_string_size(const std::string& str) {
@@ -138,11 +137,8 @@ INLINE size_t calculate_raw_forward_input_size(const RawForwardInput& input) {
     total += get_transfer_kv_info_size(t);
   }

-  const size_t cache_block_size =
-      input.async_copy_out_blocks.size() + input.copy_out_blocks.size() +
-      input.copy_in_blocks.size() + input.swap_blocks.size();
-  total += type_size * 4 +
-           cache_block_size * cache_block_info_fixed_size();
+  total += type_size * 2 +  // swap_blocks length field + batch_id
+           input.swap_blocks.size() * swap_block_info_fixed_size();

   total += type_size * 2  // empty_kv_cache + global_empty_kv_cache
            + type_size *
@@ -256,30 +252,14 @@ INLINE void write_eplb_info(char*& buffer, const EplbInfo& info) {
   write_data(buffer, info.update_layer_id);
 }

-INLINE void write_cache_block_info(char*& buffer, const CacheBlockInfo& info) {
-  *reinterpret_cast<int32_t*>(buffer) = info.device_block_id;
-  *reinterpret_cast<int32_t*>(buffer + 4) = info.host_block_id;
-  if (info.hash_key) {
-    std::memcpy(buffer + 8, info.hash_key, 16);
-  } else {
-    std::memset(buffer + 8, 0, 16);
-  }
-  buffer += cache_block_info_fixed_size();
-}
-
-INLINE void write_cache_blocks(char*& buffer,
-                               const std::vector<CacheBlockInfo>& blocks) {
+INLINE void write_swap_blocks(char*& buffer,
+                              const std::vector<BlockTransferInfo>& blocks) {
   write_data(buffer, (uint64_t)blocks.size());
-  if constexpr (sizeof(CacheBlockInfo) == cache_block_info_fixed_size()) {
-    if (!blocks.empty()) {
-      std::memcpy(
-          buffer, blocks.data(), blocks.size() * cache_block_info_fixed_size());
-      buffer += blocks.size() * cache_block_info_fixed_size();
-    }
-  } else {
-    for (const auto& b : blocks) {
-      write_cache_block_info(buffer, b);
-    }
+
+  for (const auto& b : blocks) {
+    *reinterpret_cast<int32_t*>(buffer) = b.src_block_id;
+    *reinterpret_cast<int32_t*>(buffer + 4) = b.dst_block_id;
+    buffer += swap_block_info_fixed_size();
   }
 }

@@ -394,22 +374,14 @@ INLINE void read_eplb_info(const char*& buffer, EplbInfo& info) {
   read_data(buffer, info.update_layer_id);
 }

-INLINE void read_cache_block_info(const char*& buffer, CacheBlockInfo& info) {
-  info.device_block_id = *reinterpret_cast<const int32_t*>(buffer);
-  info.host_block_id = *reinterpret_cast<const int32_t*>(buffer + 4);
-  // notice: a temporary pointer in the buffer is stored here
-  info.hash_key =
-      const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(buffer + 8));
-  buffer += 8 + 16;
-}
-
-INLINE void read_cache_blocks(const char*& buffer,
-                              std::vector<CacheBlockInfo>& blocks) {
+INLINE void read_swap_blocks(const char*& buffer,
+                             std::vector<BlockTransferInfo>& blocks) {
   uint64_t size;
   read_data(buffer, size);
-  blocks.resize(size);
-  for (auto& block : blocks) {
-    read_cache_block_info(buffer, block);
+  blocks.reserve(size);
+  for (uint64_t i = 0; i < size; i++) {
+    blocks.emplace_back(*reinterpret_cast<const int32_t*>(buffer),
+                        *reinterpret_cast<const int32_t*>(buffer + 4));
+    buffer += swap_block_info_fixed_size();  // advance past the consumed entry
   }
 }

@@ -457,10 +429,8 @@ INLINE void deserialize_raw_forward_input(
     read_transfer_kv_info(buffer, transfer);
   }

-  read_cache_blocks(buffer, input.async_copy_out_blocks);
-  read_cache_blocks(buffer, input.copy_out_blocks);
-  read_cache_blocks(buffer, input.copy_in_blocks);
-  read_cache_blocks(buffer, input.swap_blocks);
+  read_swap_blocks(buffer, input.swap_blocks);
+  read_data(buffer, input.batch_id);

   read_data(buffer, input.empty_kv_cache);
   read_data(buffer, input.global_empty_kv_cache);
@@ -509,10 +479,8 @@ INLINE void serialize_raw_forward_input(const RawForwardInput& input,
     write_transfer_kv_info(buffer, t);
   }

-
write_cache_blocks(buffer, input.async_copy_out_blocks); - write_cache_blocks(buffer, input.copy_out_blocks); - write_cache_blocks(buffer, input.copy_in_blocks); - write_cache_blocks(buffer, input.swap_blocks); + write_swap_blocks(buffer, input.swap_blocks); + write_data(buffer, input.batch_id); *reinterpret_cast(buffer) = input.empty_kv_cache; buffer += 1; @@ -746,11 +714,8 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input, input_params.cum_sum = torch::tensor(std::move(raw_input.cum_sum), tensor_options); - input_params.async_copy_out_blocks = - std::move(raw_input.async_copy_out_blocks); - input_params.copy_out_blocks = std::move(raw_input.copy_out_blocks); - input_params.copy_in_blocks = std::move(raw_input.copy_in_blocks); input_params.swap_blocks = std::move(raw_input.swap_blocks); + input_params.batch_id = std::move(raw_input.batch_id); input_params.extra_token_ids = std::move(raw_input.extra_token_ids); input_params.new_cache_slot_offsets = torch::tensor( diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp index b2eddb5c..87071c5a 100644 --- a/xllm/core/runtime/llm_engine.cpp +++ b/xllm/core/runtime/llm_engine.cpp @@ -467,21 +467,44 @@ bool LLMEngine::pull_kv_blocks(const int32_t src_dp_size, return true; } -std::vector> -LLMEngine::load_kv_blocks_from_store_async( +std::vector> LLMEngine::transfer_kv_blocks( const uint32_t dp_rank, - const std::vector& cache_block_info) { + const std::vector& block_transfer_info) { std::vector> futures; - futures.reserve(dp_local_tp_size_); + for (auto tp_rank = 0; tp_rank < dp_local_tp_size_; ++tp_rank) { - futures.emplace_back( - worker_clients_[tp_rank + dp_local_tp_size_ * dp_rank] - ->load_kv_blocks_from_store_async(cache_block_info)); + futures.emplace_back(worker_clients_[tp_rank + dp_local_tp_size_ * dp_rank] + ->transfer_kv_blocks(block_transfer_info)); } + return std::move(futures); } +void LLMEngine::transfer_kv_blocks( + const uint32_t dp_rank, + const uint64_t batch_id, + const std::vector& block_transfer_info) { + for (auto tp_rank = 0; tp_rank < dp_local_tp_size_; ++tp_rank) { + worker_clients_[tp_rank + dp_local_tp_size_ * dp_rank]->transfer_kv_blocks( + batch_id, block_transfer_info); + } +} + +void LLMEngine::prefetch_from_storage( + const uint32_t dp_rank, + const std::atomic& flag, + const std::vector& block_transfer_info, + std::vector>>* prefetch_results) { + prefetch_results->resize(dp_local_tp_size_, + std::make_shared>(0)); + for (auto tp_rank = 0; tp_rank < dp_local_tp_size_; ++tp_rank) { + worker_clients_[tp_rank + dp_local_tp_size_ * dp_rank] + ->prefetch_from_storage( + flag, block_transfer_info, prefetch_results->at(tp_rank)); + } +} + void LLMEngine::get_device_info(std::vector& device_ips, std::vector& ports) { if (worker_device_ips_.size() != worker_clients_num_ || diff --git a/xllm/core/runtime/llm_engine.h b/xllm/core/runtime/llm_engine.h index 8d6c083c..b6be9c3b 100644 --- a/xllm/core/runtime/llm_engine.h +++ b/xllm/core/runtime/llm_engine.h @@ -71,9 +71,21 @@ class LLMEngine : public Engine { const int32_t dst_dp_rank, const std::vector& dst_blocks) override; - std::vector> load_kv_blocks_from_store_async( + std::vector> transfer_kv_blocks( const uint32_t dp_rank, - const std::vector& cache_block_info) override; + const std::vector& block_transfer_info) override; + + void transfer_kv_blocks( + const uint32_t dp_rank, + const uint64_t batch_id, + const std::vector& block_transfer_info) override; + + void prefetch_from_storage( + const uint32_t dp_rank, + 
const std::atomic& flag, + const std::vector& block_transfer_info, + std::vector>>* prefetch_results) + override; void get_device_info(std::vector& device_ips, std::vector& ports) override; diff --git a/xllm/core/runtime/master.cpp b/xllm/core/runtime/master.cpp index 0ade6f56..96dd4705 100644 --- a/xllm/core/runtime/master.cpp +++ b/xllm/core/runtime/master.cpp @@ -198,8 +198,9 @@ Master::Master(const Options& options, EngineType type) : options_(options) { .host_blocks_factor(options_.host_blocks_factor()) .enable_kvcache_store(options_.enable_kvcache_store()) .store_protocol(options_.store_protocol()) - .store_master_server_entry(options_.store_master_server_entry()) - .store_metadata_connstring(options_.store_metadata_connstring()) + .store_master_server_address(options_.store_master_server_address()) + .store_metadata_server(options_.store_metadata_server()) + .store_local_hostname(options_.store_local_hostname()) .enable_continuous_kvcache(options_.enable_continuous_kvcache()) .enable_offline_inference(options_.enable_offline_inference()) .spawn_worker_path(options_.spawn_worker_path()) diff --git a/xllm/core/runtime/options.h b/xllm/core/runtime/options.h index d61ad94b..f2228449 100644 --- a/xllm/core/runtime/options.h +++ b/xllm/core/runtime/options.h @@ -145,11 +145,15 @@ struct Options { // The address information of the Master (IP:Port for default mode and // etcd://IP:Port;IP:Port;...;IP:Port for high availability mode) - PROPERTY(std::string, store_master_server_entry) = ""; + PROPERTY(std::string, store_master_server_address) = ""; // the address of the metadata service (e.g., etcd/Redis) required for // Transfer Engine initialization - PROPERTY(std::string, store_metadata_connstring) = ""; + PROPERTY(std::string, store_metadata_server) = ""; + + // the IP:Port of the local machine or an accessible domain name (default + // value used if port is not included) + PROPERTY(std::string, store_local_hostname) = ""; // dit // max requests per batch diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 428c0c3e..77f64eca 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -116,36 +116,10 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector(pb_forward_input->extra_token_ids().begin(), pb_forward_input->extra_token_ids().end()); - std::vector async_copy_out_blocks; - for (size_t i = 0; i < pb_forward_input->async_copy_out_blocks().size(); - ++i) { - async_copy_out_blocks.emplace_back( - pb_forward_input->async_copy_out_blocks()[i].device_block_id(), - pb_forward_input->async_copy_out_blocks()[i].host_block_id(), - reinterpret_cast( - pb_forward_input->async_copy_out_blocks()[i].hash_key().data())); - } - std::vector copy_out_blocks; - for (size_t i = 0; i < pb_forward_input->copy_out_blocks().size(); ++i) { - copy_out_blocks.emplace_back( - pb_forward_input->copy_out_blocks()[i].device_block_id(), - pb_forward_input->copy_out_blocks()[i].host_block_id(), - reinterpret_cast( - pb_forward_input->copy_out_blocks()[i].hash_key().data())); - } - std::vector copy_in_blocks; - for (size_t i = 0; i < pb_forward_input->copy_in_blocks().size(); ++i) { - copy_in_blocks.emplace_back( - pb_forward_input->copy_in_blocks()[i].device_block_id(), - pb_forward_input->copy_in_blocks()[i].host_block_id(), - reinterpret_cast( - pb_forward_input->copy_in_blocks()[i].hash_key().data())); - } - std::vector swap_blocks; + std::vector swap_blocks; for (size_t i = 0; i < 
pb_forward_input->swap_blocks().size(); ++i) { - swap_blocks.emplace_back( - pb_forward_input->swap_blocks()[i].device_block_id(), - pb_forward_input->swap_blocks()[i].host_block_id()); + swap_blocks.emplace_back(pb_forward_input->swap_blocks()[i].src_block_id(), + pb_forward_input->swap_blocks()[i].dst_block_id()); } // block copy kernel std::vector src_block_indices = @@ -225,9 +199,6 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, input_params.embedding_ids = std::move(embedding_ids); input_params.extra_token_ids = std::move(extra_token_ids); - input_params.async_copy_out_blocks = std::move(async_copy_out_blocks); - input_params.copy_out_blocks = std::move(copy_out_blocks); - input_params.copy_in_blocks = std::move(copy_in_blocks); input_params.swap_blocks = std::move(swap_blocks); // block copy kernel input_params.src_block_indices = @@ -236,6 +207,8 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, torch::tensor(dst_block_indices, tensor_options); input_params.cum_sum = torch::tensor(cum_sum, tensor_options); + input_params.batch_id = pb_forward_input->batch_id(); + if (pb_forward_input->embeds().size() > 0) { const int32_t rows = pb_forward_input->embeds().size(); const int32_t cols = pb_forward_input->embeds()[0].vals().size(); @@ -461,41 +434,14 @@ void forward_input_to_proto(const RawForwardInput& inputs, inputs.embedding_ids); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_extra_token_ids(), inputs.extra_token_ids); - pb_forward_input->mutable_async_copy_out_blocks()->Reserve( - inputs.async_copy_out_blocks.size()); - for (auto t : inputs.async_copy_out_blocks) { - proto::CacheBlockInfo cache_block_info; - cache_block_info.set_device_block_id(t.device_block_id); - cache_block_info.set_host_block_id(t.host_block_id); - cache_block_info.set_hash_key(t.hash_key, MURMUR_HASH3_VALUE_LEN); - *pb_forward_input->mutable_async_copy_out_blocks()->Add() = - cache_block_info; - } - pb_forward_input->mutable_copy_out_blocks()->Reserve( - inputs.copy_out_blocks.size()); - for (auto t : inputs.copy_out_blocks) { - proto::CacheBlockInfo cache_block_info; - cache_block_info.set_device_block_id(t.device_block_id); - cache_block_info.set_host_block_id(t.host_block_id); - cache_block_info.set_hash_key(t.hash_key, MURMUR_HASH3_VALUE_LEN); - *pb_forward_input->mutable_copy_out_blocks()->Add() = cache_block_info; - } - pb_forward_input->mutable_copy_in_blocks()->Reserve( - inputs.copy_in_blocks.size()); - for (auto t : inputs.copy_in_blocks) { - proto::CacheBlockInfo cache_block_info; - cache_block_info.set_device_block_id(t.device_block_id); - cache_block_info.set_host_block_id(t.host_block_id); - cache_block_info.set_hash_key(t.hash_key, MURMUR_HASH3_VALUE_LEN); - *pb_forward_input->mutable_copy_in_blocks()->Add() = cache_block_info; - } pb_forward_input->mutable_swap_blocks()->Reserve(inputs.swap_blocks.size()); for (auto t : inputs.swap_blocks) { - proto::CacheBlockInfo cache_block_info; - cache_block_info.set_device_block_id(t.device_block_id); - cache_block_info.set_host_block_id(t.host_block_id); - *pb_forward_input->mutable_swap_blocks()->Add() = cache_block_info; + proto::BlockTransferInfo block_transfer_info; + block_transfer_info.set_src_block_id(t.src_block_id); + block_transfer_info.set_dst_block_id(t.dst_block_id); + *pb_forward_input->mutable_swap_blocks()->Add() = block_transfer_info; } + pb_forward_input->set_batch_id(inputs.batch_id); // block copy kernel ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_src_block_indices(), @@ -741,36 
+687,54 @@ Token build_token(int64_t index,
   return token;
 }

-void proto_to_cache_block_info(
-    const proto::CacheBlockInfos& cache_block_info_pb,
-    std::vector<CacheBlockInfo>& cache_block_info) {
-  cache_block_info.reserve(cache_block_info_pb.contents_size());
+uint64_t proto_to_block_transfer_info(
+    const proto::BlockTransferInfos& pb_block_transfer_info,
+    std::vector<BlockTransferInfo>& block_transfer_info) {
+  block_transfer_info.reserve(pb_block_transfer_info.transfer_infos_size());

-  for (int i = 0; i < cache_block_info_pb.contents_size(); ++i) {
-    cache_block_info.emplace_back(
-        cache_block_info_pb.contents(i).device_block_id(),
-        cache_block_info_pb.contents(i).host_block_id(),
+  for (int i = 0; i < pb_block_transfer_info.transfer_infos_size(); ++i) {
+    block_transfer_info.emplace_back(
+        pb_block_transfer_info.transfer_infos(i).src_block_id(),
+        pb_block_transfer_info.transfer_infos(i).dst_block_id(),
         reinterpret_cast<const uint8_t*>(
-            cache_block_info_pb.contents(i).hash_key().data()));
+            pb_block_transfer_info.transfer_infos(i).hash_key().data()),
+        pb_block_transfer_info.transfer_infos(i).hash_key().size(),
+        TransferType(pb_block_transfer_info.transfer_type()));
   }
+
+  return pb_block_transfer_info.batch_id();
 }

-bool cache_block_info_to_proto(
-    const std::vector<CacheBlockInfo>& cache_block_info,
-    proto::CacheBlockInfos* cache_block_info_pb) {
-  cache_block_info_pb->mutable_contents()->Reserve(cache_block_info.size());
-  for (const CacheBlockInfo block_info : cache_block_info) {
-    proto::CacheBlockInfo pb_cache;
-    pb_cache.set_device_block_id(block_info.device_block_id);
-    pb_cache.set_host_block_id(block_info.host_block_id);
-    if (block_info.hash_key != nullptr) {
-      pb_cache.set_hash_key(block_info.hash_key, MURMUR_HASH3_VALUE_LEN);
-    } else {
-      LOG(ERROR) << "convert to CacheBlockInfos fail, hash key is nullptr!";
+bool block_transfer_info_to_proto(
+    const uint64_t batch_id,
+    const std::vector<BlockTransferInfo>& block_transfer_info,
+    proto::BlockTransferInfos* pb_block_transfer_info) {
+  pb_block_transfer_info->mutable_transfer_infos()->Reserve(
+      block_transfer_info.size());
+  auto transfer_type = block_transfer_info[0].transfer_type;
+  // iterate by reference: owning BlockTransferInfo copies are shallow
+  for (const BlockTransferInfo& info : block_transfer_info) {
+    if (info.hash_key == nullptr) {
+      LOG(ERROR) << "Convert to BlockTransferInfos fail, hash key is nullptr!";
       return false;
     }
-    *cache_block_info_pb->mutable_contents()->Add() = pb_cache;
+
+    if (transfer_type != info.transfer_type) {
+      LOG(ERROR) << "Convert to BlockTransferInfos fail, TransferType must be "
+                    "same, but got "
+                 << uint32_t(transfer_type) << " and "
+                 << uint32_t(info.transfer_type);
+      return false;
+    }
+
+    proto::BlockTransferInfo pb_cache;
+    pb_cache.set_src_block_id(info.src_block_id);
+    pb_cache.set_dst_block_id(info.dst_block_id);
+    pb_cache.set_hash_key(info.hash_key, MURMUR_HASH3_VALUE_LEN);
+
+    *pb_block_transfer_info->mutable_transfer_infos()->Add() = pb_cache;
   }
+  pb_block_transfer_info->set_batch_id(batch_id);
+  pb_block_transfer_info->set_transfer_type(proto::TransferType(transfer_type));
   return true;
 }
diff --git a/xllm/core/runtime/params_utils.h b/xllm/core/runtime/params_utils.h
index 8ebe7357..913e23fb 100644
--- a/xllm/core/runtime/params_utils.h
+++ b/xllm/core/runtime/params_utils.h
@@ -52,12 +52,13 @@ Token build_token(int64_t index,
                   torch::Tensor top_tokens,
                   torch::Tensor top_logprobs);

-void proto_to_cache_block_info(
-    const proto::CacheBlockInfos& cache_block_info_pb,
-    std::vector<CacheBlockInfo>& cache_block_info);
+uint64_t proto_to_block_transfer_info(
+    const proto::BlockTransferInfos& pb_block_transfer_info,
+    std::vector<BlockTransferInfo>&
block_transfer_info); -bool cache_block_info_to_proto( - const std::vector& cache_block_info, - proto::CacheBlockInfos* cache_block_info_pb); +bool block_transfer_info_to_proto( + const uint64_t batch_id, + const std::vector& block_transfer_info, + proto::BlockTransferInfos* pb_block_transfer_info); } // namespace xllm diff --git a/xllm/core/runtime/worker.cpp b/xllm/core/runtime/worker.cpp index 3ab9b6e6..271df4a0 100644 --- a/xllm/core/runtime/worker.cpp +++ b/xllm/core/runtime/worker.cpp @@ -166,9 +166,20 @@ folly::SemiFuture Worker::pull_kv_blocks_async( dst_blocks); } -folly::SemiFuture Worker::load_kv_blocks_from_store_async( - const std::vector& cache_block_info) { - return impl_->load_kv_blocks_from_store_async(cache_block_info); +uint32_t Worker::transfer_kv_blocks( + const std::vector& block_transfer_info) { + return impl_->transfer_kv_blocks(block_transfer_info); +} + +void Worker::transfer_kv_blocks( + const uint64_t batch_id, + const std::vector& block_transfer_info) { + impl_->transfer_kv_blocks(batch_id, std::move(block_transfer_info)); +} + +uint32_t Worker::prefetch_from_storage( + Slice& block_transfer_info) { + return impl_->prefetch_from_storage(block_transfer_info); } const torch::Device& Worker::device() const { return impl_->device(); } diff --git a/xllm/core/runtime/worker.h b/xllm/core/runtime/worker.h index 5b12aaca..eed1c69c 100644 --- a/xllm/core/runtime/worker.h +++ b/xllm/core/runtime/worker.h @@ -105,8 +105,15 @@ class Worker { const std::vector& src_blocks, const std::vector& dst_blocks); - virtual folly::SemiFuture load_kv_blocks_from_store_async( - const std::vector& cache_block_info); + virtual uint32_t transfer_kv_blocks( + const std::vector& block_transfer_info); + + virtual void transfer_kv_blocks( + const uint64_t batch_id, + const std::vector& block_transfer_info); + + virtual uint32_t prefetch_from_storage( + Slice& block_transfer_info); // Run the model on the given input. 
async call
   // the future returns a successful status with no meaningful value
diff --git a/xllm/core/runtime/worker_client.cpp b/xllm/core/runtime/worker_client.cpp
index e23825eb..fe6bb24f 100644
--- a/xllm/core/runtime/worker_client.cpp
+++ b/xllm/core/runtime/worker_client.cpp
@@ -159,9 +159,24 @@ folly::SemiFuture<bool> WorkerClient::pull_kv_blocks_async(
       dst_blocks);
 }

-folly::SemiFuture<uint32_t> WorkerClient::load_kv_blocks_from_store_async(
-    const std::vector<CacheBlockInfo> cache_block_info) {
-  return worker_->load_kv_blocks_from_store_async(cache_block_info);
+folly::SemiFuture<uint32_t> WorkerClient::transfer_kv_blocks(
+    const std::vector<BlockTransferInfo>& block_transfer_info) {
+  LOG(FATAL) << "WorkerClient method transfer_kv_blocks returning "
+                "folly::SemiFuture<uint32_t> is unimplemented.";
+}
+
+void WorkerClient::prefetch_from_storage(
+    const std::atomic<bool>& flag,
+    const std::vector<BlockTransferInfo>& block_transfer_info,
+    std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
+  LOG(FATAL) << "WorkerClient method prefetch_from_storage is unimplemented.";
+}
+
+void WorkerClient::transfer_kv_blocks(
+    const uint64_t batch_id,
+    const std::vector<BlockTransferInfo>& block_transfer_info) {
+  worker_->transfer_kv_blocks(batch_id, block_transfer_info);
 }

 const torch::Device& WorkerClient::device() const { return worker_->device(); }
diff --git a/xllm/core/runtime/worker_client.h b/xllm/core/runtime/worker_client.h
index 0dfe0451..d6294ca3 100644
--- a/xllm/core/runtime/worker_client.h
+++ b/xllm/core/runtime/worker_client.h
@@ -107,8 +107,17 @@ class WorkerClient {
       const std::vector<uint64_t>& src_blocks,
       const std::vector<uint64_t>& dst_blocks);

-  virtual folly::SemiFuture<uint32_t> load_kv_blocks_from_store_async(
-      const std::vector<CacheBlockInfo> cache_block_info);
+  virtual folly::SemiFuture<uint32_t> transfer_kv_blocks(
+      const std::vector<BlockTransferInfo>& block_transfer_info);
+
+  virtual void transfer_kv_blocks(
+      const uint64_t batch_id,
+      const std::vector<BlockTransferInfo>& block_transfer_info);
+
+  virtual void prefetch_from_storage(
+      const std::atomic<bool>& flag,
+      const std::vector<BlockTransferInfo>& block_transfer_info,
+      std::shared_ptr<std::atomic<uint32_t>>& success_cnt);

   // Run the model on the given input. async call
   // the future returns a successful status with no meaningful value
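Before the WorkerImpl changes below, a brief sketch of how the new fire-and-forget load path is meant to be driven. This is illustrative only; `next_batch_id()`, `hash_key`, `engine` and `dp_rank` are placeholders:

    // Hedged sketch: schedule an async host-to-device load for a batch, then
    // let the forward pass gate on the per-layer synchronizer that the worker
    // installs under the same batch_id in h2d_batch_copy().
    uint64_t batch_id = next_batch_id();  // placeholder id source
    std::vector<BlockTransferInfo> loads;
    loads.emplace_back(/*src_block_id=*/12,  // host block
                       /*dst_block_id=*/4,   // device block
                       hash_key, TransferType::H2D);
    engine->transfer_kv_blocks(dp_rank, batch_id, loads);  // returns immediately

    // Later, ModelInputParams carries the same batch_id; step_async() attaches
    // layer_wise_load_synchronizer so each layer can wait for its own H2D copy
    // via synchronize_layer(layer_id).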
#include "util/timer.h" #include "util/utils.h" +#define USE_ASYNC true + namespace xllm { constexpr uint64_t MBUF_SIZE = 128 * 1024 * 1024; +constexpr uint32_t BATCH_COPY_MAX_SIZE = 4096; +constexpr uint32_t TIMEOUT_S = 60; // second +constexpr uint32_t TIMEOUT_MS = 60000; // millisecond WorkerImpl::WorkerImpl(const ParallelArgs& parallel_args, const torch::Device& device, @@ -66,10 +72,26 @@ WorkerImpl::WorkerImpl(const ParallelArgs& parallel_args, device_.set_device(); device_.init_device_context(); - general_threadpool_.schedule([this]() mutable { device_.set_device(); }); + for (int i = 0; i < h2d_threadpool_.size(); i++) { + h2d_threadpool_.schedule_with_tid( + [this]() mutable { + device_.set_device(); + h2d_stream_[std::this_thread::get_id()] = + device_.get_stream_from_pool(TIMEOUT_MS); + }, + i); + } + for (int i = 0; i < d2h_threadpool_.size(); i++) { + d2h_threadpool_.schedule_with_tid( + [this]() mutable { + device_.set_device(); + d2h_stream_[std::this_thread::get_id()] = + device_.get_stream_from_pool(TIMEOUT_MS); + }, + i); + } prepare_stream_ = device_.get_stream_from_pool(); - copy_out_stream_ = device_.get_stream_from_pool(); sampler_ = std::make_unique(); } @@ -101,6 +123,11 @@ bool WorkerImpl::allocate_kv_cache( kv_caches_.emplace_back(key_cache, value_cache); } + key_cache_size_per_layer_ = kv_caches_[0].get_k_cache()[0].numel() * + kv_caches_[0].get_k_cache()[0].element_size(); + value_cache_size_per_layer_ = kv_caches_[0].get_v_cache()[0].numel() * + kv_caches_[0].get_v_cache()[0].element_size(); + allocate_host_kv_cache(kv_cache_shape); status_ = Status::READY; return true; @@ -114,33 +141,50 @@ bool WorkerImpl::allocate_host_kv_cache( CHECK(model_ != nullptr) << "Model is not initialized."; CHECK(host_kv_caches_.empty()) << "KV caches are already initialized."; + CHECK(device_kv_cache_shape[0][0] == device_kv_cache_shape[1][0]); std::vector> host_kv_cache_shape = device_kv_cache_shape; - host_kv_cache_shape[0][0] = + const int64_t num_layers = context_.get_model_args().n_layers(); + int64_t host_bolck_size = device_kv_cache_shape[0][0] * options_.host_blocks_factor(); - host_kv_cache_shape[1][0] = - device_kv_cache_shape[1][0] * options_.host_blocks_factor(); + host_kv_cache_shape[0][0] = num_layers; + host_kv_cache_shape[1][0] = num_layers; - // create a KVCache for each layer - const int64_t num_layers = context_.get_model_args().n_layers(); - kv_caches_.reserve(num_layers); - for (int64_t i = 0; i < num_layers; ++i) { - torch::Tensor key_cache, value_cache; - key_cache = torch::empty(host_kv_cache_shape[0], - torch::dtype(dtype_).device(torch::kCPU)); - value_cache = torch::empty(host_kv_cache_shape[1], - torch::dtype(dtype_).device(torch::kCPU)); - host_kv_caches_.emplace_back(key_cache, value_cache); - } + // create a KVCache shape: block_size * [layers, token, head, dim] + aligned_tensor_creater_ = std::make_unique( + host_kv_cache_shape, dtype_, host_bolck_size, &host_kv_caches_); + + LOG(INFO) << "Initializing host kv block size: " << host_bolck_size; + + int32_t device_id = device_.index(); + h2d_attrs_.dstLoc.id = device_id; + h2d_attrs_.dstLoc.type = aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_DEVICE; + h2d_attrs_.srcLoc.id = device_id; + h2d_attrs_.srcLoc.type = aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_HOST; + memset(h2d_attrs_.rsv, 0, 16); + + d2h_attrs_.dstLoc.id = device_id; + d2h_attrs_.dstLoc.type = aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_HOST; + d2h_attrs_.srcLoc.id = device_id; + d2h_attrs_.srcLoc.type = 
aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_DEVICE; + memset(d2h_attrs_.rsv, 0, 16); if (options_.enable_kvcache_store()) { StoreConfig config; + config.localhost_name = options_.store_local_hostname(); config.protocol = options_.store_protocol(); - config.metadata_connstring = options_.store_metadata_connstring(); - config.master_server_entry = options_.store_master_server_entry(); - config.tp_rank = options_.node_rank() % options_.dp_size(); - - kv_cache_store_ = std::make_shared(config, &host_kv_caches_); + config.metadata_server = options_.store_metadata_server(); + config.master_server_address = options_.store_master_server_address(); + config.tp_rank = options_.dp_size() > 1 + ? options_.node_rank() % options_.dp_size() + : options_.node_rank(); + config.total_size = aligned_tensor_creater_->get_total_size(); + config.tensor_data = aligned_tensor_creater_->get_base_ptr(); + + if (!KVCacheStore::get_instance().init(config, &host_kv_caches_)) { + LOG(ERROR) << "Init KVCacheStore fail!"; + return false; + } } status_ = Status::READY; @@ -375,31 +419,9 @@ void WorkerImpl::prepare_work_before_execute( fwd_inputs_on_device = inputs.micro_inputs[i].to(device_, dtype_); auto& input_params = fwd_inputs_on_device.input_params; #if defined(USE_NPU) - if (input_params.copy_out_blocks.size() > 0 || - input_params.copy_in_blocks.size() > 0) { - const int64_t num_layers = context_.get_model_args().n_layers(); - for (int layer_id = 0; layer_id < num_layers; layer_id++) { - auto key_cache = kv_caches_[layer_id].get_k_cache(); - auto host_k_cache = host_kv_caches_[layer_id].get_k_cache(); - auto value_cache = kv_caches_[layer_id].get_v_cache(); - auto host_v_cache = host_kv_caches_[layer_id].get_v_cache(); - - for (auto block_info : input_params.copy_out_blocks) { - host_k_cache[block_info.host_block_id].copy_( - key_cache[block_info.device_block_id]); - host_v_cache[block_info.host_block_id].copy_( - value_cache[block_info.device_block_id]); - } - for (auto block_info : input_params.copy_in_blocks) { - key_cache[block_info.device_block_id].copy_( - host_k_cache[block_info.host_block_id]); - value_cache[block_info.device_block_id].copy_( - host_v_cache[block_info.host_block_id]); - } - } - - offload_kv_blocks_to_store_async( - inputs.micro_inputs[i].input_params.copy_out_blocks); + if (input_params.swap_blocks.size() > 0 && + !FLAGS_enable_block_copy_kernel) { + auto& swap_blocks = input_params.swap_blocks; if (input_params.swap_blocks.size() > 0 && !FLAGS_enable_block_copy_kernel) { @@ -411,8 +433,8 @@ void WorkerImpl::prepare_work_before_execute( dst_indices.reserve(swap_blocks.size()); for (const auto& block : swap_blocks) { - src_indices.push_back(block.device_block_id); - dst_indices.push_back(block.host_block_id); + src_indices.push_back(block.src_block_id); + dst_indices.push_back(block.dst_block_id); } // batch select keys and values @@ -461,39 +483,6 @@ void WorkerImpl::prepare_work_before_execute( auto ret = prepare_stream_->synchronize(); } -folly::SemiFuture WorkerImpl::copy_out_blocks_async( - ModelInputParams& input_params) { - folly::Promise promise; - auto future = promise.getSemiFuture(); - general_threadpool_.schedule([this, - input_params = input_params, - promise = std::move(promise)]() mutable { - c10::StreamGuard streamGuard = copy_out_stream_->set_stream_guard(); - if (input_params.async_copy_out_blocks.size() > 0) { - const int64_t num_layers = context_.get_model_args().n_layers(); - for (int layer_id = 0; layer_id < num_layers; layer_id++) { - auto key_cache = 
kv_caches_[layer_id].get_k_cache(); - auto host_k_cache = host_kv_caches_[layer_id].get_k_cache(); - auto value_cache = kv_caches_[layer_id].get_v_cache(); - auto host_v_cache = host_kv_caches_[layer_id].get_v_cache(); - - for (auto block_info : input_params.async_copy_out_blocks) { - host_k_cache[block_info.host_block_id].copy_( - key_cache[block_info.device_block_id]); - host_v_cache[block_info.host_block_id].copy_( - value_cache[block_info.device_block_id]); - } - } - - offload_kv_blocks_to_store(input_params.async_copy_out_blocks); - } - auto ret = copy_out_stream_->synchronize(); - promise.setValue(ret == 0); - }); - - return future; -} - folly::SemiFuture> WorkerImpl::step_async( const BatchedForwardInputs& inputs) { BatchedForwardInputs batched_inputs_on_device; @@ -506,19 +495,19 @@ folly::SemiFuture> WorkerImpl::step_async( threadpool_.schedule([this, inputs = std::move(batched_inputs_on_device), promise = std::move(promise)]() mutable { - // run the model on the given input in working thread - std::vector> copy_futures; for (auto& input : inputs.micro_inputs) { - copy_futures.push_back( - std::move(copy_out_blocks_async(input.input_params))); + { + std::lock_guard lock(mutex_); + if (layer_wise_load_synchronizer_.count(input.input_params.batch_id) != + 0) { + input.input_params.layer_wise_load_synchronizer = + layer_wise_load_synchronizer_[input.input_params.batch_id]; + } + } } + // run the model on the given input in working thread if (!enable_schedule_overlap()) { const auto output = this->step(inputs); - std::for_each(copy_futures.begin(), - copy_futures.end(), - [](folly::SemiFuture& copy_future) { - std::move(copy_future).get(); - }); promise.setValue(output); } else { for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { @@ -551,11 +540,6 @@ folly::SemiFuture> WorkerImpl::step_async( last_step_output_valid_ = false; } } - std::for_each(copy_futures.begin(), - copy_futures.end(), - [](folly::SemiFuture& copy_future) { - std::move(copy_future).get(); - }); promise.setValue(output); } }); @@ -698,42 +682,40 @@ folly::SemiFuture WorkerImpl::pull_kv_blocks_async( return false; } -folly::SemiFuture WorkerImpl::load_kv_blocks_from_store_async( - const std::vector& cache_block_info) { - folly::Promise promise; - auto future = promise.getSemiFuture(); - general_threadpool_.schedule( - [this, &cache_block_info, promise = std::move(promise)]() mutable { - if (this->kv_cache_store_ == nullptr) { - promise.setValue(0); - return; - } - promise.setValue(this->kv_cache_store_->batch_get(cache_block_info)); - }); - return future; -} - -uint32_t WorkerImpl::offload_kv_blocks_to_store( - const std::vector& cache_block_info) { - if (kv_cache_store_ == nullptr) { - return 0; +uint32_t WorkerImpl::transfer_kv_blocks( + const std::vector& block_transfer_info) { + CHECK(!block_transfer_info.empty()); + + switch (block_transfer_info[0].transfer_type) { + case TransferType::D2G: + return offload_kv_blocks(block_transfer_info); + default: + LOG(ERROR) << "Unsupport copy type: " + << uint32_t(block_transfer_info[0].transfer_type); + return 0; } - return kv_cache_store_->batch_put(cache_block_info); } -folly::SemiFuture WorkerImpl::offload_kv_blocks_to_store_async( - const std::vector& cache_block_info) { - folly::Promise promise; - auto future = promise.getSemiFuture(); - general_threadpool_.schedule( - [this, &cache_block_info, promise = std::move(promise)]() mutable { - if (this->kv_cache_store_ == nullptr) { - promise.setValue(0); - return; +void WorkerImpl::transfer_kv_blocks( + const uint64_t 
+void WorkerImpl::transfer_kv_blocks(
+    const uint64_t batch_id,
+    const std::vector<BlockTransferInfo>& block_transfer_info) {
+  CHECK(!block_transfer_info.empty());
+  h2d_threadpool_.schedule(
+      [this,
+       batch_id = batch_id,
+       block_transfer_info = std::move(block_transfer_info)]() mutable {
+        switch (block_transfer_info[0].transfer_type) {
+          case TransferType::H2D: {
+            Slice<BlockTransferInfo> info_slice{block_transfer_info};
+            h2d_batch_copy(batch_id, info_slice);
+            break;
+          }
+          default:
+            LOG(ERROR) << "Unsupported transfer type: "
+                       << uint32_t(block_transfer_info[0].transfer_type);
+            break;
         }
-        promise.setValue(this->kv_cache_store_->batch_put(cache_block_info));
       });
-  return future;
 }
 
 folly::SemiFuture<bool> WorkerImpl::allocate_kv_cache_with_transfer_async(
@@ -758,4 +740,346 @@ int64_t WorkerImpl::get_active_activation_memory() {
       .active_activation_memory;
 }
 
+uint32_t WorkerImpl::offload_kv_blocks(
+    const std::vector<BlockTransferInfo>& block_transfer_info) {
+  if (block_transfer_info.empty()) {
+    return 0;
+  }
+
+  const int64_t num_layers = context_.get_model_args().n_layers();
+  uint32_t max_blocks_per_batch = BATCH_COPY_MAX_SIZE / (2 * num_layers);
+  uint32_t total_slice =
+      block_transfer_info.size() / max_blocks_per_batch +
+      uint32_t(block_transfer_info.size() % max_blocks_per_batch != 0);
+
+  Slice<BlockTransferInfo> transfer_info_slice(block_transfer_info);
+  std::vector<folly::SemiFuture<bool>> futures;
+  futures.reserve(total_slice);
+
+  for (size_t i = 0; i < block_transfer_info.size();
+       i += max_blocks_per_batch) {
+    folly::Promise<bool> promise;
+    auto future = promise.getSemiFuture();
+    auto slice = transfer_info_slice.slice(
+        i, std::min(i + max_blocks_per_batch, block_transfer_info.size()));
+
+    d2h_threadpool_.schedule([this,
+                              promise = std::move(promise),
+                              slice = std::move(slice)]() mutable {
+      bool ret = d2h_batch_copy(slice);
+      auto success_cnt = offload_to_store(slice);
+      if (success_cnt != slice.size()) {
+        LOG(WARNING) << "KVCacheStore did not put all blocks: " << success_cnt
+                     << "/" << slice.size();
+      }
+      promise.setValue(ret);
+    });
+
+    futures.emplace_back(std::move(future));
+  }
+
+  if (!futures.empty()) {
+    try {
+      // TODO(kangmeng): add timeout
+      auto all_results = folly::collect(futures).get();
+      if (!std::all_of(all_results.begin(),
+                       all_results.end(),
+                       [](bool result) { return result; })) {
+        LOG(FATAL) << "Not all D2H copies succeeded";
+      }
+    } catch (const std::exception& e) {
+      LOG(FATAL) << "Future execution failed: " << e.what();
+    }
+  }
+
+  return block_transfer_info.size();
+}
+
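+// Worked example of the slicing above (the constant value is illustrative
+// only): each block expands to 2 * n_layers memcpy entries (K and V per
+// layer), so with BATCH_COPY_MAX_SIZE = 4096 and n_layers = 32,
+// max_blocks_per_batch = 4096 / (2 * 32) = 64. Offloading 150 blocks then
+// gives total_slice = 150 / 64 + 1 = 3 slices (64 + 64 + 22), each copied
+// and persisted independently on d2h_threadpool_, so every aclrtMemcpyBatch
+// call stays within the batch-size limit.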
+bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
+#if defined(USE_NPU)
+  CHECK(d2h_stream_.count(std::this_thread::get_id()) != 0)
+      << "WorkerImpl::d2h_batch_copy can only be called in d2h_threadpool_.";
+
+  const int64_t num_layers = context_.get_model_args().n_layers();
+  uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
+  void** srcs = new void*[num_batches];
+  void** dsts = new void*[num_batches];
+  size_t* copy_size = new size_t[num_batches];
+  aclrtMemcpyBatchAttr attrs[1] = {d2h_attrs_};
+  size_t attrs_indexes[1] = {0};
+  size_t fail_index;
+  uint32_t curr_index = 0;
+
+  for (const auto& info : block_transfer_info) {
+    auto dst_k_cache = host_kv_caches_.at(info.dst_block_id).get_k_cache();
+    auto dst_v_cache = host_kv_caches_.at(info.dst_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[layer_id].data_ptr();
+      copy_size[curr_index] = key_cache_size_per_layer_;
+      curr_index++;
+
+      srcs[curr_index] = src_v_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[layer_id].data_ptr();
+      copy_size[curr_index] = value_cache_size_per_layer_;
+      curr_index++;
+    }
+  }
+
+  c10::StreamGuard streamGuard =
+      d2h_stream_[std::this_thread::get_id()]->set_stream_guard();
+
+  // TODO(kangmeng): change to async API
+  aclError ret = aclrtMemcpyBatch(dsts,
+                                  copy_size,
+                                  srcs,
+                                  copy_size,
+                                  num_batches,
+                                  attrs,
+                                  attrs_indexes,
+                                  1,
+                                  &fail_index);
+  if (ret != 0 || fail_index != SIZE_MAX) {
+    LOG(ERROR) << "aclrtMemcpyBatch error: " << ret
+               << ", fail_index:" << fail_index;
+    delete[] dsts;
+    delete[] srcs;
+    delete[] copy_size;
+    return false;
+  }
+
+  if (d2h_stream_[std::this_thread::get_id()]->synchronize() != 0) {
+    LOG(ERROR) << "d2h_batch_copy timeout!";
+    delete[] dsts;
+    delete[] srcs;
+    delete[] copy_size;
+    return false;
+  }
+
+  delete[] dsts;
+  delete[] srcs;
+  delete[] copy_size;
+
+#endif
+  return true;
+}
+
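+// Indexing note for the two batch-copy helpers: device caches are laid out
+// per layer (kv_caches_[layer].get_k_cache()[block]) while host blocks are
+// laid out per block (host_kv_caches_[block].get_k_cache()[layer]), so the
+// D2H path above gathers one device block across all layers into a single
+// host block, and the H2D path below scatters a host block back into every
+// layer's device cache, recording one NPU event per layer so the model can
+// start a layer as soon as its blocks have landed.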
+bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
+                                Slice<BlockTransferInfo>& block_transfer_info) {
+#if defined(USE_NPU)
+  CHECK(h2d_stream_.count(std::this_thread::get_id()) != 0)
+      << "WorkerImpl::h2d_batch_copy can only be called in h2d_threadpool_.";
+  CHECK(block_transfer_info.size() < BATCH_COPY_MAX_SIZE / 2)
+      << "h2d_batch_copy supports fewer than " << BATCH_COPY_MAX_SIZE / 2
+      << " blocks per call, but got " << block_transfer_info.size();
+
+  if (block_transfer_info.empty()) {
+    return true;
+  }
+
+  const int64_t num_layers = context_.get_model_args().n_layers();
+  uint32_t num_batches = block_transfer_info.size() * 2;
+
+  auto synchronizer = std::make_shared<NPULayerSynchronizerImpl>(num_layers);
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (layer_wise_load_synchronizer_.count(batch_id) != 0) {
+      LOG(FATAL) << "Batch id already exists!";
+    }
+    layer_wise_load_synchronizer_[batch_id] = synchronizer;
+  }
+
+  void** srcs = new void*[num_batches];
+  void** dsts = new void*[num_batches];
+  size_t* copy_size = new size_t[num_batches];
+  aclrtMemcpyBatchAttr attrs[1] = {h2d_attrs_};
+  size_t attrs_indexes[1] = {0};
+
+  c10::StreamGuard streamGuard =
+      h2d_stream_[std::this_thread::get_id()]->set_stream_guard();
+  auto stream = h2d_stream_[std::this_thread::get_id()]->get_stream()->stream();
+  aclError ret = 0;
+
+  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+    auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
+    auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
+    size_t fail_index = 0;
+    uint32_t curr_index = 0;
+    auto* event = synchronizer->get_event(layer_id);
+    auto* event_flag = synchronizer->get_event_flag(layer_id);
+
+    for (const auto& info : block_transfer_info) {
+      auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
+      auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
+      copy_size[curr_index] = key_cache_size_per_layer_;
+      curr_index++;
+
+      srcs[curr_index] = src_v_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
+      copy_size[curr_index] = value_cache_size_per_layer_;
+      curr_index++;
+    }
+
+    // TODO(kangmeng): change to async API
+    ret = aclrtMemcpyBatch(dsts,
+                           copy_size,
+                           srcs,
+                           copy_size,
+                           num_batches,
+                           attrs,
+                           attrs_indexes,
+                           1,
+                           &fail_index);
+
+    if (ret != 0 || fail_index != SIZE_MAX) {
+      LOG(ERROR) << "aclrtMemcpyBatch error: " << ret
+                 << ", fail_index:" << fail_index;
+    } else {
+      ret = aclrtRecordEvent(*event, stream);
+      if (ret != 0) {
+        LOG(ERROR) << "aclrtRecordEvent error: " << ret;
+      }
+    }
+    event_flag->store(true, std::memory_order_release);
+    if (ret != 0) break;
+  }
+
+  if (h2d_stream_[std::this_thread::get_id()]->synchronize() != 0) {
+    LOG(ERROR) << "h2d_batch_copy timeout!";
+    delete[] dsts;
+    delete[] srcs;
+    delete[] copy_size;
+    return false;
+  }
+
+  delete[] dsts;
+  delete[] srcs;
+  delete[] copy_size;
+
+#endif
+  return true;
+}
+
+uint32_t WorkerImpl::offload_to_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return block_transfer_info.size();
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchput_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_put(block_transfer_info));
+      });
+
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(std::chrono::seconds(TIMEOUT_S))
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchPut operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
+uint32_t WorkerImpl::prefetch_from_storage(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return 0;
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchget_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_get(block_transfer_info));
+      });
+
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(std::chrono::seconds(TIMEOUT_S))
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchGet operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
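+// AlignedTensorCreater below backs every host K/V block with one mmap'd,
+// mlock'd, page-aligned region so the batched DMA copies and KVCacheStore
+// registration see pinned, contiguous memory. Hypothetical usage sketch
+// (the shapes are made up for illustration, not the real cache layout):
+//
+//   std::vector<KVCache> host_blocks;
+//   auto creater = std::make_unique<AlignedTensorCreater>(
+//       std::vector<std::vector<int64_t>>{{n_layers, 16, 128},   // K shape
+//                                         {n_layers, 16, 128}},  // V shape
+//       torch::kBFloat16,
+//       /*num_tensors=*/1024,  // number of host blocks to carve out
+//       &host_blocks);         // receives one KVCache per block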
+AlignedTensorCreater::AlignedTensorCreater(
+    const std::vector<std::vector<int64_t>>& tensor_shapes,
+    const torch::ScalarType dtype,
+    const uint32_t num_tensors,
+    std::vector<KVCache>* tensors) {
+  CHECK(tensor_shapes.size() == 2)
+      << "tensor_shapes.size() must equal 2, but got " << tensor_shapes.size();
+
+  int64_t elements_per_k_tensor = 1;
+  int64_t elements_per_v_tensor = 1;
+
+  for (auto dim : tensor_shapes[0]) {
+    elements_per_k_tensor *= dim;
+  }
+  for (auto dim : tensor_shapes[1]) {
+    elements_per_v_tensor *= dim;
+  }
+
+  size_t element_size = torch::elementSize(dtype);
+  size_t bytes_per_k_tensor = elements_per_k_tensor * element_size;
+  size_t bytes_per_v_tensor = elements_per_v_tensor * element_size;
+  size_t page_size = sysconf(_SC_PAGESIZE);
+  total_size_ = num_tensors * (bytes_per_k_tensor + bytes_per_v_tensor);
+  total_size_ = ((total_size_ + page_size - 1) / page_size) * page_size;
+
+  base_ptr_ = mmap(nullptr,
+                   total_size_,
+                   PROT_READ | PROT_WRITE,
+                   MAP_PRIVATE | MAP_ANONYMOUS,
+                   -1,
+                   0);
+
+  if (base_ptr_ == MAP_FAILED) {
+    LOG(FATAL) << "Failed to allocate aligned memory pool!";
+  }
+
+  if (mlock(base_ptr_, total_size_) != 0) {
+    munmap(base_ptr_, total_size_);
+    LOG(FATAL) << "Failed to lock memory pool!";
+  }
+
+  size_t current_offset = 0;
+  auto options = torch::TensorOptions().dtype(dtype).device(torch::kCPU);
+  tensors->reserve(num_tensors);
+
+  for (size_t i = 0; i < num_tensors; ++i) {
+    void* k_tensor_ptr = static_cast<char*>(base_ptr_) + current_offset;
+    torch::Tensor k_tensor =
+        torch::from_blob(k_tensor_ptr, tensor_shapes[0], options);
+    current_offset += bytes_per_k_tensor;
+
+    void* v_tensor_ptr = static_cast<char*>(base_ptr_) + current_offset;
+    torch::Tensor v_tensor =
+        torch::from_blob(v_tensor_ptr, tensor_shapes[1], options);
+    current_offset += bytes_per_v_tensor;
+
+    tensors->emplace_back(k_tensor, v_tensor);
+  }
+
+  LOG(INFO) << "Page aligned: "
+            << ((uintptr_t)base_ptr_ % page_size == 0 ? "YES" : "NO");
+}
+
 }  // namespace xllm
diff --git a/xllm/core/runtime/worker_impl.h b/xllm/core/runtime/worker_impl.h
index 63b1560e..3bbd7156 100644
--- a/xllm/core/runtime/worker_impl.h
+++ b/xllm/core/runtime/worker_impl.h
@@ -16,6 +16,7 @@ limitations under the License.
 #pragma once
 
 #include
+#include <sys/mman.h>
 #include
 #include
@@ -45,6 +46,8 @@ limitations under the License.
 
 namespace xllm {
 
+class AlignedTensorCreater;
+
 class WorkerImpl {
  public:
  enum Status : int8_t {
@@ -146,14 +149,12 @@ class WorkerImpl {
       const std::vector<uint64_t>& src_blocks,
       const std::vector<uint64_t>& dst_blocks);
 
-  virtual folly::SemiFuture<uint32_t> load_kv_blocks_from_store_async(
-      const std::vector<CacheBlockInfo>& cache_block_info);
-
-  virtual uint32_t offload_kv_blocks_to_store(
-      const std::vector<CacheBlockInfo>& cache_block_info);
+  virtual uint32_t transfer_kv_blocks(
+      const std::vector<BlockTransferInfo>& block_transfer_info);
 
-  virtual folly::SemiFuture<uint32_t> offload_kv_blocks_to_store_async(
-      const std::vector<CacheBlockInfo>& cache_block_info);
+  virtual void transfer_kv_blocks(
+      const uint64_t batch_id,
+      const std::vector<BlockTransferInfo>& block_transfer_info);
 
   // Run the model on the given input. async call
   // the future returns a successfull status with no meaningful value
@@ -182,11 +183,21 @@ class WorkerImpl {
 
   Status get_status() const { return status_; }
 
-  folly::SemiFuture<bool> copy_out_blocks_async(ModelInputParams& input_params);
+  virtual uint32_t prefetch_from_storage(
+      Slice<BlockTransferInfo>& block_transfer_info);
 
  private:
   void update_last_step_output(const std::optional<ForwardOutput>& output);
 
+  uint32_t offload_kv_blocks(
+      const std::vector<BlockTransferInfo>& block_transfer_info);
+
+  bool d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info);
+  bool h2d_batch_copy(const uint64_t batch_id,
+                      Slice<BlockTransferInfo>& block_transfer_info);
+
+  uint32_t offload_to_store(Slice<BlockTransferInfo>& block_transfer_info);
+
 protected:
   // runtime options
   runtime::Options options_;
@@ -201,9 +212,16 @@ class WorkerImpl {
   // the task queue, step need to be executed one-by-one
   ThreadPool threadpool_;
-  // general working thread
-  // do some overlap work with model execute
-  ThreadPool general_threadpool_;
+  // working threads for data copy
+  ThreadPool h2d_threadpool_{2};
+  ThreadPool d2h_threadpool_{5};
+  ThreadPool batchget_threadpool_{5};
+  ThreadPool batchput_threadpool_{2};
+  // copy streams
+  // can only be used in h2d_threadpool_
+  std::unordered_map<std::thread::id, std::unique_ptr<Stream>> h2d_stream_;
+  // can only be used in d2h_threadpool_
+  std::unordered_map<std::thread::id, std::unique_ptr<Stream>> d2h_stream_;
 
   // dtype of the model
   torch::ScalarType dtype_;
@@ -212,7 +230,6 @@ class WorkerImpl {
   Device device_;
 
   std::unique_ptr<Stream> prepare_stream_;
-  std::unique_ptr<Stream> copy_out_stream_;
 
   // parallel args of current instance
   ParallelArgs parallel_args_;
@@ -223,6 +240,7 @@ class WorkerImpl {
   // kv caches
   std::vector<KVCache> kv_caches_;
   std::vector<KVCache> host_kv_caches_;
+  std::unique_ptr<AlignedTensorCreater> aligned_tensor_creater_;
 
   // causal LM model
   std::unique_ptr<CausalLM> model_;
@@ -245,15 +263,44 @@ class WorkerImpl {
 #if defined(USE_NPU)
   std::shared_ptr<KVCacheTransfer> kv_cache_transfer_;
+  aclrtMemcpyBatchAttr h2d_attrs_;
+  aclrtMemcpyBatchAttr d2h_attrs_;
 #endif
 
-  std::shared_ptr<KVCacheStore> kv_cache_store_;
+  uint64_t key_cache_size_per_layer_;
+  uint64_t value_cache_size_per_layer_;
 
   bool is_spec_draft_ = false;
 
   Status status_ = Status::UNINITIALIZED;
 
   torch::Tensor expert_load_data_;
+
+  mutable std::mutex mutex_;
+  std::unordered_map<uint64_t, std::shared_ptr<NPULayerSynchronizerImpl>>
+      layer_wise_load_synchronizer_;
+};
+
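+// AlignedTensorCreater owns a single page-aligned memory pool and hands out
+// torch tensors created with torch::from_blob, which merely borrow that
+// memory; the creator therefore has to outlive every host KV block it backs,
+// which is why WorkerImpl keeps it in the aligned_tensor_creater_ member
+// alongside host_kv_caches_.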
+class AlignedTensorCreater {
+ private:
+  void* base_ptr_;
+  size_t total_size_;
+
+ public:
+  AlignedTensorCreater(const std::vector<std::vector<int64_t>>& tensor_shapes,
+                       const torch::ScalarType dtype,
+                       const uint32_t num_tensors,
+                       std::vector<KVCache>* tensors);
+
+  ~AlignedTensorCreater() {
+    if (base_ptr_ != nullptr) {
+      munlock(base_ptr_, total_size_);
+      munmap(base_ptr_, total_size_);
+    }
+  }
+
+  void* get_base_ptr() const { return base_ptr_; }
+  size_t get_total_size() const { return total_size_; }
 };
 
 }  // namespace xllm
diff --git a/xllm/core/scheduler/continuous_scheduler.cpp b/xllm/core/scheduler/continuous_scheduler.cpp
index 5b637a91..30dea16b 100644
--- a/xllm/core/scheduler/continuous_scheduler.cpp
+++ b/xllm/core/scheduler/continuous_scheduler.cpp
@@ -95,7 +95,7 @@ bool ContinuousScheduler::add_request(std::shared_ptr<Request>& request) {
   CHECK(request != nullptr);
   CHECK(!request->sequences().empty());
 
-  prepare_cache_async(request);
+  prefetch_from_storage(request);
 
   if (request_queue_.write(request)) {
     return true;
@@ -223,7 +223,7 @@ void ContinuousScheduler::handle_prefill_requests(
       continue;
     }
 
-    prefill_sequence->sync_result();
+    prefill_sequence->update_prefetch_result();
     // FIXME: use actual num_tokens to handle
     // Currently overestimating the number of tokens actually processed when
     // enable prefix cache
@@ -254,11 +254,11 @@ void ContinuousScheduler::handle_prefill_requests(
         std::shared_ptr<Request> request_to_preempt =
             running_queue_offline_->back();
         ++num_online_prefill_preempt_offline_requests;
+        request_to_preempt->set_preempted();
         kv_cache_manager_->deallocate(request_to_preempt.get());
         running_queue_offline_->pop_back();
         // add preemptable request to waiting priority queue
         // TO IMPROVE?: not process this offline request in current batch
-        request_to_preempt->set_preempted();
         waiting_priority_queue_offline_.push(request_to_preempt);
       }
       if (!kv_cache_manager_->allocate(prefill_sequence.get())) {
@@ -475,20 +475,20 @@ void ContinuousScheduler::handle_decode_requests(
         std::shared_ptr<Request> request_to_preempt =
             running_queue_offline_->back();
         ++num_online_decode_preempt_offline_requests;
+        request_to_preempt->set_preempted();
         kv_cache_manager_->deallocate(request_to_preempt.get());
         running_queue_offline_->pop_back();
         // add preemptable request to waiting priority queue
-        request_to_preempt->set_preempted();
         waiting_priority_queue_offline_.push(request_to_preempt);
         continue;
       } else if (running_queue->size() > 1) {
         std::shared_ptr<Request> request_to_preempt = running_queue->back();
         if (request_to_preempt.get() != request.get()) {
           // TO IMPROVE: kv cache offload to cpu
+          request_to_preempt->set_preempted();
           kv_cache_manager_->deallocate(request_to_preempt.get());
           running_queue->pop_back();
           // add preemptable request to waiting priority queue
-          request_to_preempt->set_preempted();
           if (request_to_preempt->offline()) {
             ++num_offline_decode_preempt_offline_requests;
             waiting_priority_queue_offline_.push(request_to_preempt);
@@ -795,13 +795,44 @@ std::vector<Batch> ContinuousScheduler::prepare_batch() {
           ->create_batches(running_requests_,
                            running_sequences_,
                            running_sequences_budgets_,
-                           kv_cache_manager_->get_copy_in_cache_block_infos(),
-                           kv_cache_manager_->get_copy_out_cache_block_infos(),
-                           kv_cache_manager_->get_swap_cache_block_infos());
+                           kv_cache_manager_->get_swap_block_transfer_infos());
 
   if (!batches[0].empty()) {
     // only update the scheduling latency when there are requests to process
     COUNTER_ADD(scheduling_latency_seconds, timer.elapsed_seconds());
+
+    auto* load_block_transfer_infos =
+        kv_cache_manager_->get_load_block_transfer_infos();
+
+    for (int i = 0; i < batches.size(); i++) {
+      if (!load_block_transfer_infos->at(i).empty()) {
+        batches[i].set_batch_id();
+        engine_->transfer_kv_blocks(
+            i,
+            batches[i].batch_id(),
+            std::move(load_block_transfer_infos->at(i)));
+      }
+    }
+  }
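+
+  // The H2D loads above are tagged with a fresh batch_id so the worker can
+  // attach the matching layer-wise synchronizer when the batch executes; the
+  // D2G offloads below are fire-and-forget per DP rank, and their futures are
+  // handed to the kv_cache_manager_ so freed host blocks are only recycled
+  // once the copies have finished.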
+
+  auto* offload_block_transfer_infos =
+      kv_cache_manager_->get_offload_block_transfer_infos();
+
+  bool is_all_dp_copy_info_empty = true;
+  std::vector<folly::SemiFuture<std::vector<uint32_t>>> futures;
+  futures.resize(offload_block_transfer_infos->size());
+
+  for (int i = 0; i < futures.size(); i++) {
+    if (!offload_block_transfer_infos->at(i).empty()) {
+      futures[i] = std::move(engine_->transfer_kv_blocks(
+          i, std::move(offload_block_transfer_infos->at(i))));
+
+      is_all_dp_copy_info_empty = false;
+    }
+  }
+
+  if (!is_all_dp_copy_info_empty) {
+    kv_cache_manager_->set_offload_callback(futures);
   }
 
   GAUGE_SET(num_pending_requests,
@@ -867,10 +898,12 @@ std::vector<Batch> ContinuousScheduler::schedule_request(
   return batch;
 }
 
-void ContinuousScheduler::prepare_cache_async(
+void ContinuousScheduler::prefetch_from_storage(
     std::shared_ptr<Request>& request) {
   if (request->sequences()[0]->kv_state().num_kv_blocks() != 0 ||
       request->sequences()[0]->host_kv_state().num_kv_blocks() != 0) {
+    LOG(ERROR)
+        << "prefetch_from_storage can only be called before prepare_batch!";
     return;
   }
   for (auto& prefill_sequence : request->sequences()) {
@@ -878,18 +911,22 @@ void ContinuousScheduler::prepare_cache_async(
         kv_cache_manager_->pre_allocate(prefill_sequence.get());
     if (num_additional_blocks > 0) {
       const auto host_blocks = prefill_sequence->host_kv_state().kv_blocks();
-      std::vector<CacheBlockInfo> contents;
-      contents.reserve(num_additional_blocks);
+      std::vector<BlockTransferInfo> block_transfer_infos;
+      block_transfer_infos.reserve(num_additional_blocks);
       for (int i = host_blocks.size() - num_additional_blocks;
            i < host_blocks.size();
            i++) {
-        contents.emplace_back(
-            -1, host_blocks[i].id(), host_blocks[i].get_immutable_hash_value());
+        block_transfer_infos.emplace_back(
+            BlockTransferInfo(-1,
+                              host_blocks[i].id(),
+                              host_blocks[i].get_immutable_hash_value(),
+                              TransferType::G2H));
       }
-      auto futures = engine_->load_kv_blocks_from_store_async(
-          prefill_sequence->dp_rank(), std::move(contents));
-      prefill_sequence->set_async_result(std::move(futures));
+      engine_->prefetch_from_storage(prefill_sequence->dp_rank(),
+                                     prefill_sequence->get_termination_flag(),
+                                     std::move(block_transfer_infos),
+                                     prefill_sequence->get_prefetch_results());
     }
   }
 }
@@ -915,7 +952,7 @@ void ContinuousScheduler::step(const absl::Duration& timeout) {
       step_with_pd_ooc(batch);
     }
 
-    kv_cache_manager_->reset_copy_content();
+    kv_cache_manager_->reset_transfer_infos();
     // process request output in batch
     process_batch_output(false);
   } else {
@@ -941,7 +978,7 @@ void ContinuousScheduler::step_with_schedule_overlap(
 
   if (!cur_batch_all_empty) {
     engine_->step(batch);
-    kv_cache_manager_->reset_copy_content();
+    kv_cache_manager_->reset_transfer_infos();
   }
 
   // producer-consumer mode, make sure only one step is scheduled in advance
@@ -971,7 +1008,7 @@ void ContinuousScheduler::generate() {
     // run inference for the batch
     engine_->step(batch);
-    kv_cache_manager_->reset_copy_content();
+    kv_cache_manager_->reset_transfer_infos();
 
     // process request output in batch
     process_batch_output(false);
diff --git a/xllm/core/scheduler/continuous_scheduler.h b/xllm/core/scheduler/continuous_scheduler.h
index d2918f74..cbf6be33 100644
--- a/xllm/core/scheduler/continuous_scheduler.h
+++ b/xllm/core/scheduler/continuous_scheduler.h
@@ -264,7 +264,7 @@ class ContinuousScheduler : public Scheduler {
       size_t& num_online_decode_preempt_offline_requests,
       std::unique_ptr<DecodePriorityQueue>& running_queue);
 
-  virtual void prepare_cache_async(std::shared_ptr<Request>& request);
+  virtual void prefetch_from_storage(std::shared_ptr<Request>& request);
 
   void handle_abnormal_request(
       std::unique_ptr<DecodePriorityQueue>& running_queue,
diff --git a/xllm/core/scheduler/pd_ooc_scheduler.cpp b/xllm/core/scheduler/pd_ooc_scheduler.cpp
index c8cf4558..33dc3db5 100644
--- a/xllm/core/scheduler/pd_ooc_scheduler.cpp
+++ b/xllm/core/scheduler/pd_ooc_scheduler.cpp
@@ -367,13 +367,10 @@ std::vector<Batch> PDOOCScheduler::prepare_batch() {
     response_processor_->process_completed_requests(finished_requests);
   }
 
-  auto batches =
-      BatchFactory::get_instance(options_.dp_size())
-          ->create_batches(running_requests_,
-                           running_sequences_,
-                           running_sequences_budgets_,
-                           kv_cache_manager_->get_copy_in_cache_block_infos(),
-                           kv_cache_manager_->get_copy_out_cache_block_infos());
+  auto batches = BatchFactory::get_instance(options_.dp_size())
+                     ->create_batches(running_requests_,
+                                      running_sequences_,
+                                      running_sequences_budgets_);
 
   if (!batches[0].empty()) {
     // only update the scheduling latency when there are requests to process
diff --git a/xllm/core/scheduler/prefill_only_scheduler.cpp b/xllm/core/scheduler/prefill_only_scheduler.cpp
index b1b3d0d0..664ca639 100644
--- a/xllm/core/scheduler/prefill_only_scheduler.cpp
+++ b/xllm/core/scheduler/prefill_only_scheduler.cpp
@@ -102,7 +102,7 @@ void PrefillOnlyScheduler::handle_prefill_requests(
       continue;
     }
 
-    prefill_sequence->sync_result();
+    prefill_sequence->update_prefetch_result();
     // FIXME: use actual num_tokens to handle
     // Currently overestimating the number of tokens actually processed when
     // enable prefix cache
@@ -291,7 +291,7 @@ void PrefillOnlyScheduler::handle_last_step_prefill_requests(
       continue;
     }
 
-    prefill_sequence->sync_result();
+    prefill_sequence->update_prefetch_result();
     // FIXME: use actual num_tokens to handle
     // Currently overestimating the number of tokens actually processed when
     // enable prefix cache
@@ -629,9 +629,7 @@ std::vector<Batch> PrefillOnlyScheduler::prepare_batch() {
           ->create_batches(running_requests_,
                            running_sequences_,
                            running_sequences_budgets_,
-                           kv_cache_manager_->get_copy_in_cache_block_infos(),
-                           kv_cache_manager_->get_copy_out_cache_block_infos(),
-                           kv_cache_manager_->get_swap_cache_block_infos());
+                           kv_cache_manager_->get_swap_block_transfer_infos());
 
   if (!batches[0].empty()) {
     // only update the scheduling latency when there are requests to process
diff --git a/xllm/core/scheduler/profile/profile_manager.cpp b/xllm/core/scheduler/profile/profile_manager.cpp
index 72b3df6a..aeb2c43f 100644
--- a/xllm/core/scheduler/profile/profile_manager.cpp
+++ b/xllm/core/scheduler/profile/profile_manager.cpp
@@ -546,10 +546,8 @@ double ProfileManager::run_request(int32_t token_length,
     sequences_budget.emplace_back(token_length - prefix_length);
   }
   // build batch
-  auto batches =
-      BatchFactory::get_instance(options_.dp_size())
-          ->create_batches(
-              requests, sequences, sequences_budget, nullptr, nullptr);
+  auto batches = BatchFactory::get_instance(options_.dp_size())
+                     ->create_batches(requests, sequences, sequences_budget);
 
   absl::Time start_time = absl::Now();
   engine_->step(batches);
@@ -588,8 +586,7 @@ double ProfileManager::run_request(
   // build batch
   auto batches =
       BatchFactory::get_instance(options_.dp_size())
-          ->create_batches(
-              requests, sequences, sequences_budget, nullptr, nullptr);
+          ->create_batches(requests, sequences, sequences_budget, nullptr);
 
   absl::Time start_time = absl::Now();
   engine_->step(batches);
diff --git a/xllm/core/util/hash_util.h b/xllm/core/util/hash_util.h
index 31393d5b..ecfdde5c 100644
--- a/xllm/core/util/hash_util.h
+++ b/xllm/core/util/hash_util.h
@@ -38,6 +38,11 @@ struct Murmur3Key {
                        reinterpret_cast<const char*>(other.data),
                        MURMUR_HASH3_VALUE_LEN);
   }
+
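+  // set() rebinds the key in place via memcpy, so a preallocated Murmur3Key
+  // can be reused across lookups instead of constructing a new object.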
+  void set(const uint8_t* const input_data) {
+    std::memcpy(data, input_data, MURMUR_HASH3_VALUE_LEN);
+  }
+
   std::string debug_string() {
     std::string rt;
     for (int i = 0; i < MURMUR_HASH3_VALUE_LEN; i++) {
diff --git a/xllm/models/llm/deepseek_v2.h b/xllm/models/llm/deepseek_v2.h
index 010993a4..174dde77 100644
--- a/xllm/models/llm/deepseek_v2.h
+++ b/xllm/models/llm/deepseek_v2.h
@@ -205,6 +205,12 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
           event_flags[j] =
               input_params[j].layer_synchronizer->get_event_flag(i);
         }
+        if (input_params[j].layer_wise_load_synchronizer != nullptr) {
+          if (!input_params[j].layer_wise_load_synchronizer->synchronize_layer(
+                  i)) {
+            return torch::Tensor();
+          }
+        }
       }
       auto& layer = layers_[i];
       layer(hs,
diff --git a/xllm/models/llm/deepseek_v2_mtp.h b/xllm/models/llm/deepseek_v2_mtp.h
index 7960711c..b1c27c14 100644
--- a/xllm/models/llm/deepseek_v2_mtp.h
+++ b/xllm/models/llm/deepseek_v2_mtp.h
@@ -161,6 +161,12 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module {
           event_flags[j] =
               input_params[j].layer_synchronizer->get_event_flag(i);
         }
+        if (input_params[j].layer_wise_load_synchronizer != nullptr) {
+          if (!input_params[j].layer_wise_load_synchronizer->synchronize_layer(
+                  i)) {
+            return torch::Tensor();
+          }
+        }
       }
       auto& layer = layers_[i];
       layer(hs,
diff --git a/xllm/models/llm/glm4_moe.h b/xllm/models/llm/glm4_moe.h
index 79dbefd7..41a5cddc 100644
--- a/xllm/models/llm/glm4_moe.h
+++ b/xllm/models/llm/glm4_moe.h
@@ -177,6 +177,11 @@ class Glm4MoeModelImpl : public torch::nn::Module {
       events[0] = input_params.layer_synchronizer->get_event(i);
       event_flags[0] = input_params.layer_synchronizer->get_event_flag(i);
     }
+    if (input_params.layer_wise_load_synchronizer != nullptr) {
+      if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) {
+        return torch::Tensor();
+      }
+    }
 
     auto& layer = layers_[i];
     layer(h,
diff --git a/xllm/models/llm/glm4_moe_mtp.h b/xllm/models/llm/glm4_moe_mtp.h
index 5c005a24..0924e2e6 100644
--- a/xllm/models/llm/glm4_moe_mtp.h
+++ b/xllm/models/llm/glm4_moe_mtp.h
@@ -153,6 +153,12 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module {
       events[0] = input_params.layer_synchronizer->get_event(i);
       event_flags[0] = input_params.layer_synchronizer->get_event_flag(i);
     }
+    if (input_params.layer_wise_load_synchronizer != nullptr) {
+      if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) {
+        return torch::Tensor();
+      }
+    }
+
     auto& layer = layers_[i];
     layer(h,
           cos_pos,
diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h
index 156f8169..95488e54 100644
--- a/xllm/models/llm/llm_model_base.h
+++ b/xllm/models/llm/llm_model_base.h
@@ -287,6 +287,12 @@ class LlmModelImplBase : public torch::nn::Module {
           event_flags[j] =
               input_params[j].layer_synchronizer->get_event_flag(i);
         }
+        if (input_params[j].layer_wise_load_synchronizer != nullptr) {
+          if (!input_params[j].layer_wise_load_synchronizer->synchronize_layer(
+                  i)) {
+            return torch::Tensor();
+          }
+        }
       }
 
       auto& layer = layers_[i];
diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h
index 9c1f1e90..3ff7d2f2 100644
--- a/xllm/models/llm/qwen3_moe.h
+++ b/xllm/models/llm/qwen3_moe.h
@@ -273,6 +273,12 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
      event = input_params.layer_synchronizer->get_event(i);
      event_flag = input_params.layer_synchronizer->get_event_flag(i);
     }
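+    // As in the other model implementations above, a failed or timed-out
+    // layer-wise H2D load aborts the forward pass by returning an empty
+    // tensor.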
+    if (input_params.layer_wise_load_synchronizer != nullptr) {
+      if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) {
+        return torch::Tensor();
+      }
+    }
+
     auto& layer = layers_[i];
     layer(h,
           cos_pos,
diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto
index 5344a2e8..fa728e31 100644
--- a/xllm/proto/worker.proto
+++ b/xllm/proto/worker.proto
@@ -57,17 +57,25 @@ message PullKVCacheRequest {
   repeated uint64 dst_blocks = 5;
 }
 
-message CacheBlockInfo {
-  int32 device_block_id = 1;
-  int32 host_block_id = 2;
+enum TransferType {
+  G2H = 0;
+  H2D = 1;
+  D2G = 2;
+}
+
+message BlockTransferInfo {
+  int32 src_block_id = 1;
+  int32 dst_block_id = 2;
   bytes hash_key = 3;
 }
 
-message CacheBlockInfos {
-  repeated CacheBlockInfo contents = 1;
+message BlockTransferInfos {
+  uint64 batch_id = 1;
+  TransferType transfer_type = 2;
+  repeated BlockTransferInfo transfer_infos = 3;
 }
 
-message StoreResponse {
+message TransferStatus {
   uint32 success_cnt = 1;
 }
 
@@ -180,10 +188,7 @@ message ForwardInput {
   repeated int32 embedding_ids = 25;
   repeated int32 extra_token_ids = 26;
   EplbInfo eplb_info =27;
-  repeated CacheBlockInfo async_copy_out_blocks = 28;
-  repeated CacheBlockInfo copy_out_blocks = 29;
-  repeated CacheBlockInfo copy_in_blocks = 30;
-  repeated CacheBlockInfo swap_blocks = 31;
+  repeated BlockTransferInfo swap_blocks = 28;
   // block copy kernel
   repeated int32 src_block_indices = 32;
   repeated int32 dst_block_indices = 33;
@@ -193,6 +198,7 @@ message ForwardInput {
   repeated int64 kv_cache_start_offsets = 36;
   // beam search kernel
   repeated float acc_logprob_vec = 37;
+  uint64 batch_id = 38;
 }
 
 message BatchedForwardInputs {
@@ -250,5 +256,6 @@ service DistributeWorker {
   rpc ExecuteModel (BatchedForwardInputs) returns (ForwardOutput);
   rpc GetLastStepResult (Empty) returns (ForwardOutput);
   rpc GetActiveActivationMemory (Empty) returns (ActivationMemory);
-  rpc LoadKVCacheFromStore(CacheBlockInfos) returns (StoreResponse) {}
+  rpc TransferBlocks(BlockTransferInfos) returns (TransferStatus) {}
+  rpc PrefetchFromStorage(BlockTransferInfos) returns (Status) {}
 }
diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp
index 210dc43c..8ed57ad3 100644
--- a/xllm/xllm.cpp
+++ b/xllm/xllm.cpp
@@ -168,8 +168,9 @@ int run() {
                        FLAGS_enable_prefix_cache &&
                        (FLAGS_host_blocks_factor > 0.0))
           .store_protocol(FLAGS_store_protocol)
-          .store_master_server_entry(FLAGS_store_master_server_entry)
-          .store_metadata_connstring(FLAGS_store_metadata_connstring)
+          .store_master_server_address(FLAGS_store_master_server_address)
+          .store_metadata_server(FLAGS_store_metadata_server)
+          .store_local_hostname(FLAGS_store_local_hostname)
           .enable_multi_stream_parallel(FLAGS_enable_multi_stream_parallel)
           .enable_profile_step_time(FLAGS_enable_profile_step_time)
           .enable_profile_token_budget(FLAGS_enable_profile_token_budget)