From 6e382109915aadde6689086017a0ff9f18ba9035 Mon Sep 17 00:00:00 2001 From: xiongjun3 Date: Wed, 3 Sep 2025 10:52:43 +0800 Subject: [PATCH] feat: add TTFT predictor. --- CMakeLists.txt | 1 + vcpkg.json | 4 ++ xllm_service/common/CMakeLists.txt | 2 + xllm_service/common/ttft_predictor.cpp | 59 +++++++++++++++++++ xllm_service/common/ttft_predictor.h | 35 +++++++++++ xllm_service/common/types.h | 9 +++ .../scheduler/managers/instance_mgr.cpp | 11 ++++ .../scheduler/managers/instance_mgr.h | 4 +- 8 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 xllm_service/common/ttft_predictor.cpp create mode 100644 xllm_service/common/ttft_predictor.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c57142..6af1209 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,6 +99,7 @@ endif() # find all dependencies from vcpkg find_package(Boost REQUIRED) find_package(Boost REQUIRED COMPONENTS serialization) +find_package(Eigen3 CONFIG REQUIRED) find_package(glog CONFIG REQUIRED) find_package(gflags CONFIG REQUIRED) find_package(leveldb CONFIG REQUIRED) diff --git a/vcpkg.json b/vcpkg.json index d57f269..de4fce1 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -42,6 +42,10 @@ "name": "boost-serialization", "version>=": "1.84.0" }, + { + "name": "eigen3", + "version>=": "3.4.0" + }, { "name": "protobuf", "version>=": "3.21.12", diff --git a/xllm_service/common/CMakeLists.txt b/xllm_service/common/CMakeLists.txt index cf4ba67..3afb9cf 100644 --- a/xllm_service/common/CMakeLists.txt +++ b/xllm_service/common/CMakeLists.txt @@ -12,6 +12,7 @@ cc_library( macros.h slice.h threadpool.h + ttft_predictor.h types.h utils.h hash_util.h @@ -22,6 +23,7 @@ cc_library( global_gflags.cpp json_reader.cpp threadpool.cpp + ttft_predictor.cpp utils.cpp hash_util.cpp xllm/uuid.cpp diff --git a/xllm_service/common/ttft_predictor.cpp b/xllm_service/common/ttft_predictor.cpp new file mode 100644 index 0000000..39e52c1 --- /dev/null +++ b/xllm_service/common/ttft_predictor.cpp @@ -0,0 +1,59 
@@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm-service/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ttft_predictor.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <limits>
+#include <utility>
+#include <vector>
+
+namespace xllm_service {
+
+namespace {
+
+// Highest polynomial degree fitted to the profiling samples.
+constexpr Eigen::Index kDegree = 2;
+
+}  // namespace
+
+// Fits ttft = c0 + c1 * length + ... + ck * length^k to the profiling samples
+// by least squares, where each sample is a (length, ttft) pair.
+//
+// The degree k is kDegree, lowered to (number of samples - 1) when there are
+// too few samples to determine kDegree + 1 coefficients. Empty profiling data
+// produces the single coefficient 0, so every prediction is 0.
+//
+// NOTE(review): the element type is assumed to be
+// std::pair<int32_t, int64_t> -- confirm against ttft_predictor.h.
+TtftPredictor::TtftPredictor(
+    const std::vector<std::pair<int32_t, int64_t>>& ttft_profiling_data) {
+  if (ttft_profiling_data.empty()) {
+    coefficients_ = Eigen::VectorXd::Zero(1);
+    return;
+  }
+
+  const auto num_samples =
+      static_cast<Eigen::Index>(ttft_profiling_data.size());
+  // Never ask for more coefficients than there are samples; otherwise the
+  // system is underdetermined and the fitted curve is not unique.
+  const Eigen::Index num_coeffs =
+      std::min<Eigen::Index>(kDegree + 1, num_samples);
+
+  // Row i of the Vandermonde matrix is [1, x_i, x_i^2, ...]; target(i) is the
+  // measured ttft of sample i.
+  Eigen::MatrixXd vandermonde(num_samples, num_coeffs);
+  Eigen::VectorXd target(num_samples);
+  Eigen::Index row = 0;
+  for (const auto& sample : ttft_profiling_data) {
+    const double x = static_cast<double>(sample.first);
+    // Running product instead of std::pow for the small integer exponents.
+    double power = 1.0;
+    for (Eigen::Index col = 0; col < num_coeffs; ++col) {
+      vandermonde(row, col) = power;
+      power *= x;
+    }
+    target(row) = static_cast<double>(sample.second);
+    ++row;
+  }
+
+  coefficients_ = vandermonde.colPivHouseholderQr().solve(target);
+}
+
+// Evaluates the fitted polynomial at `length` and returns it as an integer.
+//
+// The fractional part is truncated. A NaN or non-positive value is reported
+// as 0 and a value beyond the int64_t range is saturated, because a negative
+// latency is meaningless and converting NaN or an out-of-range double to
+// int64_t is undefined behaviour.
+int64_t TtftPredictor::predict_ttft(int32_t length) {
+  // Horner's scheme: c0 + x * (c1 + x * (c2 + ...)).
+  const double x = static_cast<double>(length);
+  double result = 0.0;
+  for (Eigen::Index i = coefficients_.size() - 1; i >= 0; --i) {
+    result = result * x + coefficients_(i);
+  }
+
+  if (std::isnan(result) || result <= 0.0) {
+    return 0;
+  }
+  // int64_t max is not representable as a double; the conversion rounds up to
+  // 2^63, which is exactly the first value that no longer fits.
+  constexpr double kInt64Limit =
+      static_cast<double>(std::numeric_limits<int64_t>::max());
+  if (result >= kInt64Limit) {
+    return std::numeric_limits<int64_t>::max();
+  }
+  return static_cast<int64_t>(result);
+}
+
+}  // namespace xllm_service
diff --git a/xllm_service/common/ttft_predictor.h b/xllm_service/common/ttft_predictor.h
new file mode 100644
index 0000000..0aa3344
--- /dev/null
+++ b/xllm_service/common/ttft_predictor.h
@@ -0,0 +1,35
@@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm-service/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <Eigen/Dense>
+
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+namespace xllm_service {
+
+// Predicts TTFT from the input length by evaluating a polynomial whose
+// coefficients are fitted to profiling samples (see ttft_predictor.cpp).
+class TtftPredictor final {
+ public:
+  // Fits the predictor to `ttft_profiling_data`, a list of (length, ttft)
+  // samples. Empty data yields a predictor that always returns 0.
+  //
+  // explicit: a vector of samples must not silently convert to a predictor.
+  //
+  // NOTE(review): the element type is assumed to be
+  // std::pair<int32_t, int64_t> -- confirm against the callers.
+  explicit TtftPredictor(
+      const std::vector<std::pair<int32_t, int64_t>>& ttft_profiling_data);
+
+  // Returns the fitted polynomial evaluated at `length`, converted to
+  // int64_t.
+  int64_t predict_ttft(int32_t length);
+
+ private:
+  // Polynomial coefficients, lowest order first: coefficients_(i) multiplies
+  // length^i.
+  Eigen::VectorXd coefficients_;
+};
+
+}  // namespace xllm_service
diff --git a/xllm_service/common/types.h b/xllm_service/common/types.h
index db11e71..d224052 100644
--- a/xllm_service/common/types.h
+++ b/xllm_service/common/types.h
@@ -139,6 +139,8 @@ struct InstanceMetaInfo {
   std::vector k_cache_ids;
   std::vector v_cache_ids;
   int32_t dp_size;
+  // ttft profiling data: (length, ttft) samples used to fit TtftPredictor
+  std::vector<std::pair<int32_t, int64_t>> ttft_profiling_data;
 
   // latest heatbeat timestamp
   uint64_t latest_timestamp = 0;
@@ -155,6 +157,7 @@ struct InstanceMetaInfo {
     json_val["k_cache_ids"] = k_cache_ids;
     json_val["v_cache_ids"] = v_cache_ids;
     json_val["dp_size"] = dp_size;
+    json_val["ttft_profiling_data"] = ttft_profiling_data;
 
     return json_val;
   }
@@ -189,6 +192,17 @@ struct InstanceMetaInfo {
 
       dp_size = json_value.at("dp_size").get();
 
+      // Optional field: metadata from an instance that does not report
+      // profiling data must still parse, so look the key up with find()
+      // instead of at(), which throws when the key is absent.
+      const auto profiling_it = json_value.find("ttft_profiling_data");
+      if (profiling_it != json_value.end() && profiling_it->is_array()) {
+        for (const auto& item : *profiling_it) {
+          if (item.is_array() && item.size() == 2) {
+            ttft_profiling_data.emplace_back(item[0], item[1]);
+          }
+        }
+      }
+
set_init_timestamp(); } catch (const std::exception& e) { LOG(ERROR) << "json str:" << json_str diff --git a/xllm_service/scheduler/managers/instance_mgr.cpp b/xllm_service/scheduler/managers/instance_mgr.cpp index ca57b2d..2b04c41 100644 --- a/xllm_service/scheduler/managers/instance_mgr.cpp +++ b/xllm_service/scheduler/managers/instance_mgr.cpp @@ -66,6 +66,11 @@ void InstanceMgr::init() { for (auto& it : ETCD_KEYS_PREFIX_MAP) { etcd_client_->get_prefix(it.second, &instances_); } + // create ttft predictor for each instance + for (auto& pair : instances_) { + ttft_predictors_.insert_or_assign( + pair.first, TtftPredictor(pair.second.ttft_profiling_data)); + } LOG(INFO) << "Load instance info from etcd:" << instances_.size(); std::vector channel_creat_fail_insts; prefill_index_.reserve(instances_.size()); @@ -94,6 +99,7 @@ void InstanceMgr::init() { } for (auto& name : channel_creat_fail_insts) { instances_.erase(name); + ttft_predictors_.erase(name); } } { @@ -334,6 +340,10 @@ void InstanceMgr::update_instance_metainfo(const etcd::Response& response, continue; } + // create ttft predictor for instance + ttft_predictors_.emplace( + iter.first, TtftPredictor(iter.second.ttft_profiling_data)); + instances_.insert(std::make_pair(iter.first, std::move(iter.second))); switch (iter.second.type) { @@ -385,6 +395,7 @@ void InstanceMgr::update_instance_metainfo(const etcd::Response& response, } instances_.erase(iter); + ttft_predictors_.erase(iter); cached_channels_.erase(iter); { std::lock_guard lock(update_mutex_); diff --git a/xllm_service/scheduler/managers/instance_mgr.h b/xllm_service/scheduler/managers/instance_mgr.h index e3c4935..4915aac 100644 --- a/xllm_service/scheduler/managers/instance_mgr.h +++ b/xllm_service/scheduler/managers/instance_mgr.h @@ -22,11 +22,12 @@ limitations under the License. 
#include #include -#include "../etcd_client/etcd_client.h" #include "common/macros.h" #include "common/options.h" #include "common/threadpool.h" +#include "common/ttft_predictor.h" #include "common/types.h" +#include "scheduler/etcd_client/etcd_client.h" #include "xllm_rpc_service.pb.h" namespace xllm_service { @@ -80,6 +81,7 @@ class InstanceMgr final { std::shared_mutex inst_mutex_; std::unordered_map instances_; + std::unordered_map ttft_predictors_; std::vector prefill_index_; std::vector decode_index_; uint64_t next_prefill_index_ = 0;