Skip to content

Commit dbb858a

Browse files
lancellyQiJune
andauthored
[TRTLLM-10029][scheduler] Re-implement MicroBatchScheduler and CapacityScheduler in Python (#10273)
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com> Signed-off-by: Lance Liao <108499334+lancelly@users.noreply.github.com> Co-authored-by: junq <22017000+QiJune@users.noreply.github.com> Co-authored-by: Lanyu Liao <lancelly@users.noreply.github.com>
1 parent c6320d9 commit dbb858a

File tree

8 files changed

+1167
-76
lines changed

8 files changed

+1167
-76
lines changed

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ void initBindings(nb::module_& m)
132132
.def_rw("max_new_tokens", &GenLlmReq::mMaxNewTokens)
133133
.def_rw("sampling_config", &GenLlmReq::mSamplingConfig)
134134
.def_prop_rw("state", &GenLlmReq::getState, &GenLlmReq::setState)
135+
.def_prop_ro("state_value", [](GenLlmReq const& self) { return static_cast<int>(self.getState()); })
135136
.def_prop_rw("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming)
136137
.def_rw("end_id", &GenLlmReq::mEndId)
137138
.def_rw("pad_id", &GenLlmReq::mPadId)
@@ -175,6 +176,7 @@ void initBindings(nb::module_& m)
175176
.def_prop_ro("is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
176177
.def_prop_ro(
177178
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
179+
.def_prop_ro("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
178180
.def_prop_ro("is_context_init_state", &GenLlmReq::isContextInitState)
179181
.def_prop_ro("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
180182
.def_prop_ro("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
@@ -253,7 +255,20 @@ void initBindings(nb::module_& m)
253255
})
254256
.def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
255257
.def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
256-
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
258+
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
259+
.def("get_unique_tokens", nb::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, nb::const_),
260+
nb::arg("beam"))
261+
.def("get_unique_tokens", nb::overload_cast<>(&GenLlmReq::getUniqueTokens, nb::const_))
262+
.def("get_encoder_unique_tokens",
263+
[](GenLlmReq& self)
264+
{
265+
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
266+
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
267+
{
268+
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
269+
}
270+
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
271+
});
257272

258273
nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr())
259274
.def(

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
481481
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, nb::call_guard<nb::gil_scoped_release>())
482482
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
483483
nb::call_guard<nb::gil_scoped_release>())
484+
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, nb::arg("unique_tokens"),
485+
nb::arg("llm_request"), nb::call_guard<nb::gil_scoped_release>())
484486
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard<nb::gil_scoped_release>())
485487
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
486488
nb::call_guard<nb::gil_scoped_release>())
@@ -524,7 +526,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
524526
nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true,
525527
nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr,
526528
nb::arg("enable_indexer_k_cache") = false, nb::arg("indexer_k_cache_quant_block_size") = 128,
527-
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>());
529+
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>())
530+
.def(
531+
"scheduling_has_free_blocks",
532+
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
533+
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
534+
nb::arg("num_required"), nb::arg("window_size"), nb::call_guard<nb::gil_scoped_release>())
535+
.def_prop_ro(
536+
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
528537
}
529538

530539
void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ void initBindings(pybind11::module_& m)
136136
.def_readwrite("max_new_tokens", &GenLlmReq::mMaxNewTokens)
137137
.def_readwrite("sampling_config", &GenLlmReq::mSamplingConfig)
138138
.def_property("state", &GenLlmReq::getState, &GenLlmReq::setState)
139+
.def_property_readonly("state_value", [](GenLlmReq const& self) { return static_cast<int>(self.getState()); })
139140
.def_property("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming)
140141
.def_readwrite("end_id", &GenLlmReq::mEndId)
141142
.def_readwrite("pad_id", &GenLlmReq::mPadId)
@@ -181,6 +182,7 @@ void initBindings(pybind11::module_& m)
181182
"is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
182183
.def_property_readonly(
183184
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
185+
.def_property_readonly("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
184186
.def_property_readonly("is_context_init_state", &GenLlmReq::isContextInitState)
185187
.def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
186188
.def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
@@ -259,7 +261,20 @@ void initBindings(pybind11::module_& m)
259261
})
260262
.def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
261263
.def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
262-
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
264+
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
265+
.def("get_unique_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, py::const_),
266+
py::arg("beam"))
267+
.def("get_unique_tokens", py::overload_cast<>(&GenLlmReq::getUniqueTokens, py::const_))
268+
.def("get_encoder_unique_tokens",
269+
[](GenLlmReq& self)
270+
{
271+
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
272+
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
273+
{
274+
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
275+
}
276+
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
277+
});
263278

264279
py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
265280
.def(py::init<>(

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
485485
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, py::call_guard<py::gil_scoped_release>())
486486
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
487487
py::call_guard<py::gil_scoped_release>())
488+
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, py::arg("unique_tokens"),
489+
py::arg("llm_request"), py::call_guard<py::gil_scoped_release>())
488490
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard<py::gil_scoped_release>())
489491
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
490492
py::call_guard<py::gil_scoped_release>())
@@ -519,7 +521,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
519521
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true,
520522
py::arg("kv_connector_manager") = nullptr, py::arg("enable_indexer_k_cache") = false,
521523
py::arg("indexer_k_cache_quant_block_size") = 128, py::arg("indexer_k_cache_index_head_dim") = 0,
522-
py::call_guard<py::gil_scoped_release>());
524+
py::call_guard<py::gil_scoped_release>())
525+
.def(
526+
"scheduling_has_free_blocks",
527+
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
528+
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
529+
py::arg("num_required"), py::arg("window_size"), py::call_guard<py::gil_scoped_release>())
530+
.def_property_readonly(
531+
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
523532
}
524533

525534
void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler,
4040
TRTLLMSampler)
4141
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
42-
SimpleScheduler)
42+
SimpleScheduler, SimpleUnifiedScheduler)
4343
from .seq_slot_manager import SeqSlotManager
4444

4545
GB = 1 << 30
@@ -852,15 +852,29 @@ def create_py_executor_instance(
852852
if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager:
853853
scheduler_capacity += 1
854854

855-
capacity_scheduler = BindCapacityScheduler(
856-
scheduler_capacity,
857-
kv_cache_manager.impl if kv_cache_manager is not None else None,
858-
peft_cache_manager.impl if peft_cache_manager is not None else None,
859-
scheduler_config.capacity_scheduler_policy,
860-
two_step_lookahead=mapping.has_pp())
861-
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
862-
ctx_chunk_config)
863-
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
855+
use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1"
856+
if use_python_scheduler:
857+
scheduler = SimpleUnifiedScheduler(
858+
max_batch_size=max_batch_size,
859+
max_num_tokens=max_num_tokens,
860+
kv_cache_manager=kv_cache_manager.impl
861+
if kv_cache_manager is not None else None,
862+
peft_cache_manager=peft_cache_manager.impl
863+
if peft_cache_manager is not None else None,
864+
scheduler_policy=scheduler_config.capacity_scheduler_policy,
865+
ctx_chunk_config=ctx_chunk_config,
866+
two_step_lookahead=mapping.has_pp(),
867+
scheduler_capacity=scheduler_capacity)
868+
else:
869+
capacity_scheduler = BindCapacityScheduler(
870+
scheduler_capacity,
871+
kv_cache_manager.impl if kv_cache_manager is not None else None,
872+
peft_cache_manager.impl if peft_cache_manager is not None else None,
873+
scheduler_config.capacity_scheduler_policy,
874+
two_step_lookahead=mapping.has_pp())
875+
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
876+
ctx_chunk_config)
877+
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
864878

865879
config = model_engine.model.model_config.pretrained_config
866880
attention_type = AttentionTypeCpp.MLA if is_mla(

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,6 +2041,7 @@ def _waiting_requests(self, context_requests: list[LlmRequest],
20412041
def _schedule(self):
20422042
scheduler_output = self.scheduler.schedule_request(
20432043
self.active_requests, self.inflight_req_ids)
2044+
20442045
scheduled_context_requests = scheduler_output.context_requests
20452046
if self.enable_attention_dp and self.attention_dp_enable_balance:
20462047
scheduled_context_requests = self._balance_adp_requests(
@@ -2060,6 +2061,7 @@ def _schedule(self):
20602061
scheduled_requests.context_requests = scheduled_context_requests
20612062
scheduled_requests.generation_requests = scheduler_output.generation_requests
20622063
scheduled_requests.paused_requests = scheduler_output.paused_requests
2064+
20632065
return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
20642066

20652067
@nvtx_range("_check_disagg_gen_transfer_status")

0 commit comments

Comments
 (0)