Skip to content

Commit d1e93ce

Browse files
xinfei-shiLLLLKKKK
authored and committed
fix, modify block reserve logic
1 parent 9a20e23 commit d1e93ce

26 files changed

+116
-41
lines changed

docs/backend/pd_disaggregation.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
"| **LOAD_CACHE_TIMEOUT_MS** | Timeout for remote KVCache loading (milliseconds) | `5000` |\n",
124124
"| **DECODE_RETRY_TIMES** | Number of retries for decode process, 0 means retry disabled | `100` |\n",
125125
"| **DECODE_RETRY_TIMEOUT_MS** | Total timeout for decode process retries (milliseconds) | `100` |\n",
126+
"| **DECODE_RETRY_INTERVAL_MS** | Interval for decode process retries (milliseconds) | `1` |\n",
126127
"| **RDMA_CONNECT_RETRY_TIMES** | Number of retries for RDMA connection establishment | `5000` |\n",
127128
"| **DECODE_POLLING_KV_CACHE_STEP_MS** | Interval time for polling KV loading status (milliseconds) | `30` |\n",
128129
"| **DECODE_ENTRANCE** | Whether Decode serves as traffic entry point | `false` |"

rtp_llm/config/gpt_init_model_parameters.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ class GptInitModelParameters:
240240
decode_polling_kv_cache_step_ms: int
241241
decode_retry_timeout_ms: int
242242
decode_retry_times: int
243+
decode_retry_interval_ms: int
243244
deepseek_mscale_all_dim: float
244245
deepseek_rope_mscale: float
245246
dp_rank: int
@@ -1102,6 +1103,10 @@ def update_common(
11021103
self.py_env_configs.pd_separation_config.decode_retry_timeout_ms
11031104
)
11041105
logging.info(f"decode_retry_timeout_ms: {self.decode_retry_timeout_ms}")
1106+
self.decode_retry_interval_ms = (
1107+
self.py_env_configs.pd_separation_config.decode_retry_interval_ms
1108+
)
1109+
logging.info(f"decode_retry_interval_ms: {self.decode_retry_interval_ms}")
11051110

11061111
self.rdma_connect_retry_times = (
11071112
self.py_env_configs.pd_separation_config.rdma_connect_retry_times
@@ -1133,6 +1138,7 @@ def update_common(
11331138
logging.info(
11341139
f"scheduler_reserve_resource_ratio: {self.scheduler_reserve_resource_ratio}"
11351140
)
1141+
11361142
self.reuse_cache = self.py_env_configs.py_kv_cache_config.reuse_cache
11371143
logging.info(f"reuse_cache: {self.reuse_cache}")
11381144
self.pre_allocate_op_mem = bool(int(os.environ.get("PRE_ALLOCATE_OP_MEM", 1)))

rtp_llm/config/py_config_modules.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ def __init__(self):
648648
# Decode related configuration
649649
self.decode_retry_times: int = 100
650650
self.decode_retry_timeout_ms: int = 100
651+
self.decode_retry_interval_ms: int = 1
651652
self.decode_polling_kv_cache_step_ms: int = 30
652653
self.decode_entrance: int = 0
653654

@@ -676,6 +677,9 @@ def update_from_env(self):
676677
self.decode_retry_timeout_ms = int(
677678
os.environ.get("DECODE_RETRY_TIMEOUT_MS", self.decode_retry_timeout_ms)
678679
)
680+
self.decode_retry_interval_ms = int(
681+
os.environ.get("DECODE_RETRY_INTERVAL_MS", self.decode_retry_interval_ms)
682+
)
679683
self.decode_polling_kv_cache_step_ms = int(
680684
os.environ.get(
681685
"DECODE_POLLING_KV_CACHE_STEP_MS", self.decode_polling_kv_cache_step_ms
@@ -700,6 +704,7 @@ def to_string(self):
700704
f"prefill_max_wait_timeout_ms: {self.prefill_max_wait_timeout_ms}\n"
701705
f"decode_retry_times: {self.decode_retry_times}\n"
702706
f"decode_retry_timeout_ms: {self.decode_retry_timeout_ms}\n"
707+
f"decode_retry_interval_ms: {self.decode_retry_interval_ms}\n"
703708
f"decode_polling_kv_cache_step_ms: {self.decode_polling_kv_cache_step_ms}\n"
704709
f"decode_entrance: {self.decode_entrance}\n"
705710
f"rdma_connect_retry_times: {self.rdma_connect_retry_times}\n"

rtp_llm/cpp/api_server/test/mock/MockEngineBase.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class MockEngineBase: public EngineBase {
1919
std::vector<GenerateStreamPtr>(const std::vector<std::shared_ptr<GenerateInput>>& inputs));
2020
MOCK_METHOD0(stop, absl::Status());
2121
MOCK_METHOD2(preRun, absl::StatusOr<GenerateStreamPtr>(const std::shared_ptr<GenerateInput>&, preRunMode));
22-
MOCK_METHOD(KVCacheInfo, getCacheStatusInfo, (int64_t, bool), (const, override));
22+
MOCK_METHOD(KVCacheInfo, getCacheStatusInfo, (int64_t, bool), (override));
2323
};
2424

2525
} // namespace rtp_llm

rtp_llm/cpp/cache/CacheManager.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@ CacheManager::~CacheManager() {
9292
allocator_.reset();
9393
}
9494

95-
uint32_t CacheManager::totalBlocks() const {
95+
size_t CacheManager::totalBlocks() const {
9696
return allocator_->totalBlocks();
9797
}
9898

99-
uint32_t CacheManager::maxSeqLen() const {
99+
size_t CacheManager::maxSeqLen() const {
100100
return totalBlocks() * seq_size_per_block_;
101101
}
102102

@@ -110,7 +110,7 @@ void CacheManager::reportMetricsLoop() {
110110
{
111111
std::lock_guard<std::mutex> guard(mutex_);
112112
collector.kv_cache_item_num = block_cache_.size();
113-
auto available_blocks = availableBlockNums();
113+
auto available_blocks = availableBlockNumsWithoutLock();
114114
collector.kv_cache_left_seq = available_blocks * seq_size_per_block_;
115115
collector.kv_cache_available_blocks = available_blocks;
116116
collector.kv_cache_free_blocks = freeBlockNums();
@@ -156,11 +156,16 @@ size_t CacheManager::freeBlockNums() const {
156156
return allocator_->freeBlockNums();
157157
}
158158

159-
size_t CacheManager::availableBlockNums() const {
159+
size_t CacheManager::availableBlockNums() {
160+
std::lock_guard<std::mutex> guard(mutex_);
161+
return available_blocks_;
162+
}
163+
164+
size_t CacheManager::availableBlockNumsWithoutLock() {
160165
return available_blocks_;
161166
}
162167

163-
KVCacheInfo CacheManager::getKVCacheInfo(int64_t latest_version, bool need_cache_keys) const {
168+
KVCacheInfo CacheManager::getKVCacheInfo(int64_t latest_version, bool need_cache_keys) {
164169
auto snapshot = block_cache_.cacheSnapshot(latest_version);
165170
std::vector<int64_t> cachekeys;
166171
if (need_cache_keys) {

rtp_llm/cpp/cache/CacheManager.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,10 @@ class CacheManager {
103103

104104
const CacheConfig& cacheConfig() const;
105105
size_t freeBlockNums() const;
106-
size_t availableBlockNums() const;
107-
KVCacheInfo getKVCacheInfo(int64_t latest_version, bool need_cache_keys) const;
108-
uint32_t maxSeqLen() const;
106+
size_t availableBlockNums();
107+
size_t totalBlocks() const;
108+
size_t maxSeqLen() const;
109+
KVCacheInfo getKVCacheInfo(int64_t latest_version, bool need_cache_keys);
109110
const KVCacheAllocator::KVCacheBuffer& kvCacheBuffer() const;
110111

111112
std::tuple<bool, KVCacheResource> malloc(const KVCacheAllocator::SimpleMallocInfo& malloc_info);
@@ -150,10 +151,10 @@ class CacheManager {
150151
protected:
151152
const BlockCache& blockCache() const;
152153
size_t cacheItemNum() const;
153-
uint32_t totalBlocks() const;
154154
void initFreeBlock();
155155
rtp_llm::BufferPtr tryAllocateMaxBuffer();
156156
void allocateAndSync();
157+
size_t availableBlockNumsWithoutLock();
157158

158159
MatchInfo matchImpl(const AdvancedMallocInfo& malloc_info);
159160
std::tuple<bool, std::vector<int>> mallocIndex(const KVCacheAllocator::SimpleMallocInfo& malloc_info);

rtp_llm/cpp/config/ConfigModules.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,12 @@ struct BatchDecodeSchedulerConfig {
198198
};
199199

200200
struct FIFOSchedulerConfig {
201-
int64_t max_context_batch_size = 1;
202-
int scheduler_reserve_resource_ratio = 5;
203-
bool enable_fast_gen = false;
204-
bool enable_partial_fallback = false;
205-
int64_t fast_gen_context_budget = -1;
201+
int64_t max_context_batch_size = 1;
202+
int scheduler_reserve_resource_ratio = 5;
203+
bool enable_fast_gen = false;
204+
bool enable_partial_fallback = false;
205+
int64_t fast_gen_context_budget = -1;
206+
206207
std::string to_string() const;
207208
void update_from_env_for_test();
208209
};

rtp_llm/cpp/config/GptInitParameter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ class GptInitParameter {
252252
int64_t prefill_max_wait_timeout_ms_ = 0;
253253
int64_t decode_retry_times_ = 0;
254254
int64_t decode_retry_timeout_ms_ = 0;
255+
int64_t decode_retry_interval_ms_ = 1;
255256
int64_t decode_polling_kv_cache_step_ms_ = 0;
256257
int64_t decode_polling_call_prefill_ms_ = 0;
257258
int64_t rdma_connect_retry_times_ = 0;

rtp_llm/cpp/engine_base/EngineBase.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class EngineBase {
8383
virtual absl::StatusOr<GenerateStreamPtr> preRun(const std::shared_ptr<GenerateInput>& generate_input,
8484
preRunMode mode) = 0;
8585

86-
virtual KVCacheInfo getCacheStatusInfo(int64_t latest_version, bool need_cache_keys) const = 0;
86+
virtual KVCacheInfo getCacheStatusInfo(int64_t latest_version, bool need_cache_keys) = 0;
8787

8888
virtual const ResourceContext& resourceContext() const {
8989
return resource_context_;

rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.cc

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,18 @@ FIFOScheduler::FIFOScheduler(const rtp_llm::GptInitParameter& params,
1717
max_seq_len_(params.max_seq_len_),
1818
max_batch_tokens_size_(params.max_batch_tokens_size_),
1919
max_generate_batch_size_(params.max_generate_batch_size_),
20-
reserve_block_num_(params.scheduler_reserve_resource_ratio_ * cache_manager->availableBlockNums() / 100),
2120
// not support fallback when use pd_speration:use_cache_store
2221
enable_partial_fallback_(params.enable_partial_fallback_ && params.role_type_ == RoleType::PDFUSION),
2322
enable_whole_fallback_(params.role_type_ == RoleType::PDFUSION),
2423
enable_fast_gen_(params.enable_fast_gen_),
2524
need_fill_fake_stream_(params.dp_size_ > 1 && params.tp_rank_ == 0),
2625
fast_gen_max_context_len_(params.fast_gen_max_context_len_),
2726
metrics_reporter_(metrics_reporter) {
28-
RTP_LLM_LOG_INFO("max_generate_batch_size %d", max_generate_batch_size_);
29-
RTP_LLM_LOG_INFO("max_batch_tokens_size %d", max_batch_tokens_size_);
27+
reserve_block_num_ = params.scheduler_reserve_resource_ratio_ * cache_manager->availableBlockNums() / 100;
28+
RTP_LLM_LOG_INFO("max_generate_batch_size is [%d], max_batch_tokens_size is [%d], reserve_block_num is [%d]",
29+
max_generate_batch_size_,
30+
max_batch_tokens_size_,
31+
reserve_block_num_);
3032
}
3133

3234
FIFOScheduler::~FIFOScheduler() {
@@ -228,13 +230,27 @@ bool FIFOScheduler::evaluateNewStream(const list<GenerateStreamPtr>& streams,
228230
return false;
229231
}
230232

231-
auto result = new_stream->initKVBlock(token_capacity_, reserve_step);
233+
auto old_blocks = new_stream->maxBlockSize();
234+
auto result = new_stream->initKVBlock(token_capacity_, reserve_step);
232235
if (result.ok() && enable_fast_gen_) {
233236
token_capacity_ -= result.value();
234237
RTP_LLM_LOG_DEBUG(
235238
"after stream [%ld] acquireCapacity, token_capacity is %d", new_stream->streamId(), token_capacity_);
236239
}
237-
return result.ok() && cache_manager_->availableBlockNums() >= reserve_block_num_;
240+
if (result.ok()) {
241+
if (cache_manager_->availableBlockNums() >= reserve_block_num_) {
242+
return true;
243+
} else {
244+
RTP_LLM_LOG_INFO(
245+
"current availableBlockNums is [%ld], reserve_block_num is [%ld], so stream [%ld] malloc failed",
246+
cache_manager_->availableBlockNums(),
247+
reserve_block_num_,
248+
new_stream->streamId());
249+
new_stream->tryReleaseKVBlock(new_stream->maxBlockSize() - old_blocks);
250+
return false;
251+
}
252+
}
253+
return false;
238254
}
239255

240256
list<GenerateStreamPtr> FIFOScheduler::scheduleNew(size_t reserve_step) {

0 commit comments

Comments
 (0)