Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
64bce0f
re-implement micro batch scheduler and capacity scheduler in python
QiJune Dec 17, 2025
034fffb
refine
QiJune Dec 17, 2025
927b417
enable SimpleUnifiedScheduler
QiJune Dec 17, 2025
3609b20
fix
QiJune Dec 17, 2025
c901b21
fix
QiJune Dec 17, 2025
84cebc9
fix
QiJune Dec 17, 2025
490f8e9
fix
QiJune Dec 17, 2025
4e62403
fix
QiJune Dec 17, 2025
87caccb
fix
QiJune Dec 17, 2025
d1aebe7
fix
QiJune Dec 17, 2025
4d1f530
fix
QiJune Dec 17, 2025
641236d
fix
QiJune Dec 17, 2025
162d59e
enable py scheduler
QiJune Dec 17, 2025
707fb4a
support bert
QiJune Dec 17, 2025
fbc8486
fix
QiJune Dec 18, 2025
d344670
fix
QiJune Dec 18, 2025
6617a47
fix
QiJune Dec 18, 2025
63c09c6
fix
QiJune Dec 18, 2025
c2bffa5
fix
QiJune Dec 18, 2025
2a3a7f2
fix gemma
QiJune Dec 19, 2025
2f30b99
fix lora
QiJune Dec 19, 2025
411c254
implement scheduler using python
lancelly Dec 24, 2025
80b7253
fix part of CI failures by exposing more C++ API
lancelly Dec 25, 2025
bc443d0
Merge branch 'main' into unified_python_scheduler
lancelly Dec 26, 2025
7a26528
fix scheduler capacity for disagg gen init reqs
lancelly Dec 26, 2025
4b65790
use cpp scheduler by default for now
lancelly Dec 27, 2025
d7fb900
add stats
lancelly Dec 28, 2025
f6cf566
optimize scheduler after profiling with line_profiler
lancelly Jan 6, 2026
6af3e00
enable py scheduler for ci
lancelly Jan 6, 2026
82fac4d
disable python scheduler by default
lancelly Jan 14, 2026
f8c4a57
add ut for pyscheduler
lancelly Jan 18, 2026
baeee83
use the built-in types dict, set and tuple
lancelly Jan 18, 2026
fc794cb
fix ut
lancelly Jan 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ void initBindings(nb::module_& m)
.def_prop_ro("is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
.def_prop_ro(
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
.def_prop_ro("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
.def_prop_ro("is_context_init_state", &GenLlmReq::isContextInitState)
.def_prop_ro("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_prop_ro("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
Expand Down Expand Up @@ -252,7 +253,20 @@ void initBindings(nb::module_& m)
})
.def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
.def("get_unique_tokens", nb::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, nb::const_),
nb::arg("beam"))
.def("get_unique_tokens", nb::overload_cast<>(&GenLlmReq::getUniqueTokens, nb::const_))
.def("get_encoder_unique_tokens",
[](GenLlmReq& self)
{
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
{
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
}
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
});

nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr())
.def(
Expand Down
11 changes: 10 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, nb::call_guard<nb::gil_scoped_release>())
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
nb::call_guard<nb::gil_scoped_release>())
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, nb::arg("unique_tokens"),
nb::arg("llm_request"), nb::call_guard<nb::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard<nb::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
nb::call_guard<nb::gil_scoped_release>())
Expand Down Expand Up @@ -524,7 +526,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true,
nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr,
nb::arg("enable_indexer_k_cache") = false, nb::arg("indexer_k_cache_quant_block_size") = 128,
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>());
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>())
.def(
"scheduling_has_free_blocks",
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
nb::arg("num_required"), nb::arg("window_size"), nb::call_guard<nb::gil_scoped_release>())
.def_prop_ro(
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}

void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)
Expand Down
16 changes: 15 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ void initBindings(pybind11::module_& m)
"is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
.def_property_readonly(
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
.def_property_readonly("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
.def_property_readonly("is_context_init_state", &GenLlmReq::isContextInitState)
.def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
Expand Down Expand Up @@ -258,7 +259,20 @@ void initBindings(pybind11::module_& m)
})
.def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
.def("get_unique_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, py::const_),
py::arg("beam"))
.def("get_unique_tokens", py::overload_cast<>(&GenLlmReq::getUniqueTokens, py::const_))
.def("get_encoder_unique_tokens",
[](GenLlmReq& self)
{
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
{
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
}
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
});

py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
.def(py::init<>(
Expand Down
11 changes: 10 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, py::call_guard<py::gil_scoped_release>())
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
py::call_guard<py::gil_scoped_release>())
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, py::arg("unique_tokens"),
py::arg("llm_request"), py::call_guard<py::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard<py::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
py::call_guard<py::gil_scoped_release>())
Expand Down Expand Up @@ -519,7 +521,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true,
py::arg("kv_connector_manager") = nullptr, py::arg("enable_indexer_k_cache") = false,
py::arg("indexer_k_cache_quant_block_size") = 128, py::arg("indexer_k_cache_index_head_dim") = 0,
py::call_guard<py::gil_scoped_release>());
py::call_guard<py::gil_scoped_release>())
.def(
"scheduling_has_free_blocks",
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
py::arg("num_required"), py::arg("window_size"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}

void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
Expand Down
1 change: 1 addition & 0 deletions tensorrt_llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# Disable UCC to WAR allgather issue before NGC PyTorch 25.12 upgrade.
os.environ["OMPI_MCA_coll_ucc_enable"] = "0"
os.environ["TLLM_USE_PYTHON_SCHEDULER"] = "1"


def _add_trt_llm_dll_directory():
Expand Down
34 changes: 24 additions & 10 deletions tensorrt_llm/_torch/pyexecutor/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler,
TRTLLMSampler)
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
SimpleScheduler)
SimpleScheduler, SimpleUnifiedScheduler)
from .seq_slot_manager import SeqSlotManager

GB = 1 << 30
Expand Down Expand Up @@ -837,15 +837,29 @@ def create_py_executor_instance(
if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager:
scheduler_capacity += 1

capacity_scheduler = BindCapacityScheduler(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1"
if use_python_scheduler:
scheduler = SimpleUnifiedScheduler(
max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
kv_cache_manager=kv_cache_manager.impl
if kv_cache_manager is not None else None,
peft_cache_manager=peft_cache_manager.impl
if peft_cache_manager is not None else None,
scheduler_policy=scheduler_config.capacity_scheduler_policy,
ctx_chunk_config=ctx_chunk_config,
two_step_lookahead=mapping.has_pp(),
scheduler_capacity=scheduler_capacity)
else:
capacity_scheduler = BindCapacityScheduler(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)

config = model_engine.model.model_config.pretrained_config
attention_type = AttentionTypeCpp.MLA if is_mla(
Expand Down
Loading