Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ void initBindings(nb::module_& m)
.def_prop_ro("is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
.def_prop_ro(
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
.def_prop_ro("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
.def_prop_ro("is_context_init_state", &GenLlmReq::isContextInitState)
.def_prop_ro("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_prop_ro("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
Expand Down Expand Up @@ -252,7 +253,20 @@ void initBindings(nb::module_& m)
})
.def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
.def("get_unique_tokens", nb::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, nb::const_),
nb::arg("beam"))
.def("get_unique_tokens", nb::overload_cast<>(&GenLlmReq::getUniqueTokens, nb::const_))
.def("get_encoder_unique_tokens",
[](GenLlmReq& self)
{
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
{
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
}
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
});

nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr())
.def(
Expand Down
11 changes: 10 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, nb::call_guard<nb::gil_scoped_release>())
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
nb::call_guard<nb::gil_scoped_release>())
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, nb::arg("unique_tokens"),
nb::arg("llm_request"), nb::call_guard<nb::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard<nb::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
nb::call_guard<nb::gil_scoped_release>())
Expand Down Expand Up @@ -524,7 +526,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true,
nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr,
nb::arg("enable_indexer_k_cache") = false, nb::arg("indexer_k_cache_quant_block_size") = 128,
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>());
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>())
.def(
"scheduling_has_free_blocks",
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
nb::arg("num_required"), nb::arg("window_size"), nb::call_guard<nb::gil_scoped_release>())
.def_prop_ro(
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}

void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)
Expand Down
16 changes: 15 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ void initBindings(pybind11::module_& m)
"is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
.def_property_readonly(
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
.def_property_readonly("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
.def_property_readonly("is_context_init_state", &GenLlmReq::isContextInitState)
.def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
Expand Down Expand Up @@ -258,7 +259,20 @@ void initBindings(pybind11::module_& m)
})
.def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
.def("get_unique_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, py::const_),
py::arg("beam"))
.def("get_unique_tokens", py::overload_cast<>(&GenLlmReq::getUniqueTokens, py::const_))
.def("get_encoder_unique_tokens",
[](GenLlmReq& self)
{
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
{
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
}
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
});

py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
.def(py::init<>(
Expand Down
11 changes: 10 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, py::call_guard<py::gil_scoped_release>())
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
py::call_guard<py::gil_scoped_release>())
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, py::arg("unique_tokens"),
py::arg("llm_request"), py::call_guard<py::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard<py::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
py::call_guard<py::gil_scoped_release>())
Expand Down Expand Up @@ -519,7 +521,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true,
py::arg("kv_connector_manager") = nullptr, py::arg("enable_indexer_k_cache") = false,
py::arg("indexer_k_cache_quant_block_size") = 128, py::arg("indexer_k_cache_index_head_dim") = 0,
py::call_guard<py::gil_scoped_release>());
py::call_guard<py::gil_scoped_release>())
.def(
"scheduling_has_free_blocks",
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
py::arg("num_required"), py::arg("window_size"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}

void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
Expand Down
34 changes: 24 additions & 10 deletions tensorrt_llm/_torch/pyexecutor/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler,
TRTLLMSampler)
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
SimpleScheduler)
SimpleScheduler, SimpleUnifiedScheduler)
from .seq_slot_manager import SeqSlotManager

GB = 1 << 30
Expand Down Expand Up @@ -837,15 +837,29 @@ def create_py_executor_instance(
if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager:
scheduler_capacity += 1

capacity_scheduler = BindCapacityScheduler(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1"
if use_python_scheduler:
scheduler = SimpleUnifiedScheduler(
max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
kv_cache_manager=kv_cache_manager.impl
if kv_cache_manager is not None else None,
peft_cache_manager=peft_cache_manager.impl
if peft_cache_manager is not None else None,
scheduler_policy=scheduler_config.capacity_scheduler_policy,
ctx_chunk_config=ctx_chunk_config,
two_step_lookahead=mapping.has_pp(),
scheduler_capacity=scheduler_capacity)
else:
capacity_scheduler = BindCapacityScheduler(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)

config = model_engine.model.model_config.pretrained_config
attention_type = AttentionTypeCpp.MLA if is_mla(
Expand Down
Loading