diff --git a/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h b/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
index 70690411797..1b2ac21b2ce 100644
--- a/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
+++ b/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
@@ -28,7 +28,9 @@ namespace tensorrt_llm::batch_manager
 namespace kv_cache_manager
 {
 class BaseKVCacheManager;
-}
+struct BlockKey;
+struct BlockKeyHasher;
+} // namespace kv_cache_manager
 
 class BasePeftCacheManager;
 } // namespace tensorrt_llm::batch_manager
@@ -60,6 +62,29 @@ class BaseCapacityScheduler
         return mNoScheduleAfterState;
     }
 
+    /// @brief Check whether a request should be skipped by the scheduler.
+    /// @param req The request to check
+    /// @param allowDisaggGeneration Allow disagg_generation_init requests to be scheduled
+    /// @return true if the request should be skipped, false otherwise
+    [[nodiscard]] bool shouldSkipRequest(
+        std::shared_ptr<LlmRequest> const& req, bool allowDisaggGeneration = true) const
+    {
+        // Check the basic scheduling conditions first.
+        bool basicSkipCondition
+            = !req->hasReachedState(getNoScheduleUntilState()) || req->hasReachedState(getNoScheduleAfterState());
+
+        if (allowDisaggGeneration)
+        {
+            // Allow disagg_generation_init requests to be scheduled by excluding them from the skip check.
+            return !req->isDisaggGenerationInitState() && basicSkipCondition;
+        }
+        else
+        {
+            // Make no exception for disagg_generation_init requests.
+            return basicSkipCondition;
+        }
+    }
+
 private:
     /// The state until/after which the scheduler should not schedule requests
     LlmRequestState mNoScheduleUntilState;
@@ -140,6 +165,49 @@ class StaticBatchScheduler : public GuaranteedNoEvictScheduler
         OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
 };
 
+/// @brief Schedule requests with a non-mixed batching strategy.
+/// @details Ensures that each batch contains only context requests or only generation requests.
+///          Request types are never mixed within a batch, and context requests have priority.
+class NonMixBatchingScheduler : public BaseCapacityScheduler
+{
+public:
+    NonMixBatchingScheduler(SizeType32 maxNumRequests,
+        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
+        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
+
+    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
+        kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
+        OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
+        OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
+
+private:
+    SizeType32 mMaxNumRequests;
+    /// @brief Records whether the last scheduled batch was context phase (true) or generation phase (false).
+    ///        Initialized to false so that context requests are prioritized in the first scheduling cycle.
+    mutable bool mLastWasContext{false};
+
+    /// @brief Categorize active requests into context and generation requests.
+    [[nodiscard]] std::pair<RequestVector, RequestVector> categorizeRequests(RequestList const& activeRequests) const;
+
+    /// @brief Determine which request type to schedule based on availability and the alternating policy.
+    [[nodiscard]] bool determineSchedulingStrategy(
+        RequestVector const& contextRequests, RequestVector const& generationRequests) const;
+
+    /// @brief Schedule requests of the selected type while tracking KV cache and PEFT resources.
+    [[nodiscard]] RequestVector scheduleRequests(RequestVector const& requestsToSchedule, bool shouldScheduleContext,
+        bool skippingIsRelevant, kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
+        OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
+        OptionalRef<BasePeftCacheManager const> peftCacheManager,
+        std::unordered_set<kv_cache_manager::BlockKey, kv_cache_manager::BlockKeyHasher>& newlyContributedContextBlocks,
+        std::unordered_set<kv_cache_manager::BlockKey, kv_cache_manager::BlockKeyHasher>&
+            newlyContributedCrossContextBlocks) const;
+
+    /// @brief Check whether enough PEFT cache pages are available for the request.
+    [[nodiscard]] bool checkPeftResources(std::shared_ptr<LlmRequest> const& req,
+        OptionalRef<BasePeftCacheManager const> peftCacheManager, SizeType32 maxPeftCachePages,
+        SizeType32& claimedPeftPages, std::unordered_set<uint64_t>& uniqTaskIds) const;
+};
+
 class CapacityScheduler : public Algorithm
 {
 public:
@@ -169,7 +237,7 @@ class CapacityScheduler : public Algorithm
 
 private:
     std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler,
-        StaticBatchScheduler>
+        StaticBatchScheduler, NonMixBatchingScheduler>
         mScheduler;
 };
 
diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h
index e248cb1c3cd..c5eaf68a329 100644
--- a/cpp/include/tensorrt_llm/executor/types.h
+++ b/cpp/include/tensorrt_llm/executor/types.h
@@ -206,7 +206,11 @@ enum class CapacitySchedulerPolicy
     /// @brief kSTATIC_BATCH does not schedule new requests until all requests in current batch are completed.
     /// Similar to kGUARANTEED_NO_EVICT, requests will run to completion without eviction.
-    kSTATIC_BATCH = 2
+    kSTATIC_BATCH = 2,
+
+    /// @brief kNON_MIX_BATCHING ensures that each batch contains only context requests or only generation requests.
+    /// Context and generation requests are never mixed in the same batch, and context requests have priority.
+ kNON_MIX_BATCHING = 3 }; std::ostream& operator<<(std::ostream& os, CapacitySchedulerPolicy policy); diff --git a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp index d765bcf3173..041d6063408 100644 --- a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp +++ b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp @@ -17,6 +17,7 @@ #include "tensorrt_llm/batch_manager/capacityScheduler.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/peftCacheManager.h" #include "tensorrt_llm/batch_manager/scheduledBlocksManager.h" #include "tensorrt_llm/common/logger.h" @@ -151,13 +152,20 @@ StaticBatchScheduler::StaticBatchScheduler( { } +NonMixBatchingScheduler::NonMixBatchingScheduler( + SizeType32 maxNumRequests, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState) + : BaseCapacityScheduler(noScheduleUntilState, noScheduleAfterState) + , mMaxNumRequests(maxNumRequests) +{ +} + std::tuple MaxRequestsScheduler::operator()(RequestList const& activeRequests) const { RequestVector scheduledRequests; for (auto const& req : activeRequests) { // if request cannot be scheduled yet or request should no longer be scheduled, skip - if (!req->hasReachedState(getNoScheduleUntilState()) || req->hasReachedState(getNoScheduleAfterState())) + if (shouldSkipRequest(req, false)) { continue; } @@ -183,6 +191,204 @@ std::tuple StaticBatchScheduler::operator()( return this->impl(kvCacheManager, crossKvCacheManager, peftCacheManager, activeRequests); } +std::tuple NonMixBatchingScheduler::operator()( + kv_cache_manager::BaseKVCacheManager const& kvCacheManager, + OptionalRef crossKvCacheManager, + OptionalRef peftCacheManager, RequestList const& activeRequests) const +{ + // Categorize requests into context and generation types + auto [contextRequests, generationRequests] = categorizeRequests(activeRequests); + + // Early return if no schedulable requests + if (contextRequests.empty() && generationRequests.empty()) + { + return {RequestVector{}, RequestVector{}}; + } + + // Determine scheduling strategy (context priority) + bool shouldScheduleContext = determineSchedulingStrategy(contextRequests, generationRequests); + + // Initialize KV cache block reuse optimization + bool skippingIsRelevant + = kvCacheManager.isEnableBlockReuse() || (crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()); + + // Prefill with blocks contributed by already executing chunked context requests + std::unordered_set newlyContributedContextBlocks; + std::unordered_set newlyContributedCrossContextBlocks; + + if (skippingIsRelevant) + { + std::tie(newlyContributedContextBlocks, newlyContributedCrossContextBlocks) + = prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager, crossKvCacheManager); + } + + // Schedule requests with resource management and optimization + RequestVector const& requestsToSchedule = shouldScheduleContext ? 
contextRequests : generationRequests; + RequestVector scheduledRequests + = scheduleRequests(requestsToSchedule, shouldScheduleContext, skippingIsRelevant, kvCacheManager, + crossKvCacheManager, peftCacheManager, newlyContributedContextBlocks, newlyContributedCrossContextBlocks); + + // Update state and logging + if (!scheduledRequests.empty()) + { + mLastWasContext = shouldScheduleContext; + TLLM_LOG_INFO("[NonMixBatchingScheduler-v2] Scheduled %lu %s requests, next will try %s", + scheduledRequests.size(), shouldScheduleContext ? "context/disagg-init" : "generation", + !mLastWasContext ? "context/disagg-init" : "generation"); + } + + return {std::move(scheduledRequests), RequestVector{}}; +} + +std::pair NonMixBatchingScheduler::categorizeRequests( + RequestList const& activeRequests) const +{ + RequestVector contextRequests; + RequestVector generationRequests; + + for (auto const& req : activeRequests) + { + // if request cannot be scheduled yet or request should no longer be scheduled, skip + if (shouldSkipRequest(req, true)) + { + continue; + } + + // Categorize requests by type + if (req->isContextInitState() || req->isDisaggGenerationInitState()) + { + contextRequests.emplace_back(req); + } + else if (req->isGenerationInProgressState()) + { + generationRequests.emplace_back(req); + } + } + + return {std::move(contextRequests), std::move(generationRequests)}; +} + +bool NonMixBatchingScheduler::determineSchedulingStrategy( + RequestVector const& contextRequests, RequestVector const& generationRequests) const +{ + // No requests available + if (contextRequests.empty() && generationRequests.empty()) + { + return false; + } + + // Only one type available - schedule whatever is available + if (contextRequests.empty()) + { + return false; + } + if (generationRequests.empty()) + { + return true; + } + + // Simply alternate: if last was context, schedule generation; if last was generation, schedule context + return !mLastWasContext; +} + +RequestVector NonMixBatchingScheduler::scheduleRequests(RequestVector const& requestsToSchedule, + bool shouldScheduleContext, bool skippingIsRelevant, kv_cache_manager::BaseKVCacheManager const& kvCacheManager, + OptionalRef crossKvCacheManager, + OptionalRef peftCacheManager, + std::unordered_set& newlyContributedContextBlocks, + std::unordered_set& newlyContributedCrossContextBlocks) const +{ + RequestVector scheduledRequests; + + // Resource management - KV cache blocks + auto reservedBlocks = kv_cache_manager::NoEvictScheduledBlocksManager(kvCacheManager); + auto reservedCrossBlocks = crossKvCacheManager + ? std::optional(kv_cache_manager::NoEvictScheduledBlocksManager(*crossKvCacheManager)) + : std::nullopt; + + // Resource management - PEFT cache + auto const maxPeftCachePages + = peftCacheManager ? 
peftCacheManager->getMaxDevicePages() : std::numeric_limits<SizeType32>::max();
+    SizeType32 claimedPeftPages{0};
+    std::unordered_set<uint64_t> uniqTaskIds{};
+
+    for (auto const& req : requestsToSchedule)
+    {
+        if (scheduledRequests.size() >= static_cast<std::size_t>(mMaxNumRequests))
+        {
+            break;
+        }
+
+        // Generation requests: schedule directly without resource checks
+        // (consistent with GuaranteedNoEvictScheduler - their resources are already allocated).
+        if (!shouldScheduleContext)
+        {
+            scheduledRequests.emplace_back(req);
+            reservedBlocks.decrementReservedBlocks(*req);
+            if (reservedCrossBlocks)
+            {
+                reservedCrossBlocks->decrementReservedBlocks(*req);
+            }
+            // Track PEFT pages for generation requests.
+            bool const reqHasLora = req->getLoraTaskId().has_value();
+            bool const isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
+            if (isNewTask)
+            {
+                claimedPeftPages += peftCacheManager ? peftCacheManager->determineNumPages(req) : 0;
+                uniqTaskIds.insert(req->getLoraTaskId().value());
+            }
+        }
+        else
+        {
+            // Context requests: check whether it is beneficial to skip the request for block reuse
+            // (exclude DisaggGenerationInit, consistent with GuaranteedNoEvictScheduler).
+            if (skippingIsRelevant && !req->isDisaggGenerationInitState()
+                && beneficialToSkip(req, kvCacheManager, crossKvCacheManager, newlyContributedContextBlocks,
+                    newlyContributedCrossContextBlocks))
+            {
+                continue;
+            }
+
+            // Context requests: check KV cache block availability.
+            bool enoughBlocks = reservedBlocks.enoughAvailableBlocks(*req);
+            bool enoughCrossBlocks = reservedCrossBlocks ? reservedCrossBlocks->enoughAvailableBlocks(*req) : true;
+
+            // Check PEFT resource constraints.
+            bool reqHasLora = req->getLoraTaskId().has_value();
+            bool isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
+            auto neededPeftPages = isNewTask && peftCacheManager ? peftCacheManager->determineNumPages(req) : 0;
+            auto availablePeftPages = maxPeftCachePages - claimedPeftPages;
+
+            // Schedule the context request only if all resources are available.
+            if (enoughBlocks && enoughCrossBlocks && (isNewTask ?
neededPeftPages <= availablePeftPages : true)) + { + scheduledRequests.emplace_back(req); + + // Synchronously update all resource states + reservedBlocks.decrementReservedBlocks(*req); + if (reservedCrossBlocks) + { + reservedCrossBlocks->decrementReservedBlocks(*req); + } + + // Update PEFT resource state + if (isNewTask) + { + claimedPeftPages += neededPeftPages; + uniqTaskIds.insert(req->getLoraTaskId().value()); + } + } + else if (!enoughBlocks || !enoughCrossBlocks) + { + // Resource insufficient, stop scheduling more context requests + break; + } + } + } + + return scheduledRequests; +} + std::tuple GuaranteedNoEvictScheduler::operator()( kv_cache_manager::BaseKVCacheManager const& kvCacheManager, OptionalRef crossKvCacheManager, @@ -235,10 +441,7 @@ std::tuple GuaranteedNoEvictScheduler::impl( for (auto const& req : activeRequests) { // if request cannot be scheduled yet or request should no longer be scheduled, skip - if ( - // Allow disagg_generation_init requests to be scheduled, so that we'll allocate their KV cache - !req->isDisaggGenerationInitState() - && (!req->hasReachedState(getNoScheduleUntilState()) || req->hasReachedState(getNoScheduleAfterState()))) + if (shouldSkipRequest(req, true)) { continue; } @@ -272,6 +475,10 @@ std::tuple GuaranteedNoEvictScheduler::impl( } } + TLLM_LOG_INFO( + "GuaranteedNoEvictScheduler: generation requests scheduledRequests.size() = %lu, " + "pendingDisGenInitRequests.size() = %lu, pendingRequests.size() = %lu", + scheduledRequests.size(), pendingDisGenInitRequests.size(), pendingRequests.size()); // If StaticBatchScheduling == true check if we can add pending requests only when no requests are active. // Otherwise, add just check that we can add pending requests. if (!StaticBatchScheduling || scheduledRequests.size() == 0) @@ -327,7 +534,10 @@ std::tuple GuaranteedNoEvictScheduler::impl( } } } + TLLM_LOG_INFO( + "GuaranteedNoEvictScheduler: context requests scheduledRequests.size() = %lu", scheduledRequests.size()); } + TLLM_LOG_INFO("GuaranteedNoEvictScheduler: final scheduledRequests.size() = %lu", scheduledRequests.size()); return {std::move(scheduledRequests), RequestVector{}}; } @@ -373,10 +583,7 @@ std::tuple MaxUtilizationScheduler::operator()( TLLM_LOG_DEBUG("MaxUtilizationScheduler: scheduling request ID %lu", req->mRequestId); // if request cannot be scheduled yet or request should no longer be scheduled, skip - if ( - // Allow disagg_generation_init requests to be scheduled, so that we'll allocate their KV cache - !req->isDisaggGenerationInitState() - && (!req->hasReachedState(getNoScheduleUntilState()) || req->hasReachedState(getNoScheduleAfterState()))) + if (shouldSkipRequest(req, true)) { TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu cannot / should not be scheduled", req->mRequestId); reqIt++; @@ -479,6 +686,10 @@ CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests, { mScheduler = StaticBatchScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState}; } + else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kNON_MIX_BATCHING) + { + mScheduler = NonMixBatchingScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState}; + } else { throw std::runtime_error("Unsupported capacity scheduler policy"); @@ -507,7 +718,8 @@ std::tuple CapacityScheduler::opera = scheduler(*kvCacheManager, peftCacheManager, activeRequests); } else if constexpr (std::is_same_v, GuaranteedNoEvictScheduler> - || std::is_same_v, StaticBatchScheduler>) + || std::is_same_v, StaticBatchScheduler> + || 
std::is_same_v, NonMixBatchingScheduler>) { std::tie(tmpFittingRequests, pausedRequests) = scheduler(*kvCacheManager, crossKvCacheManager, peftCacheManager, activeRequests); diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp index 388af63cac8..93d4638042a 100644 --- a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp @@ -90,7 +90,8 @@ void initBindings(nb::module_& m) nb::enum_(m, "CapacitySchedulerPolicy") .value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION) .value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT) - .value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH); + .value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH) + .value("NON_MIX_BATCHING", tle::CapacitySchedulerPolicy::kNON_MIX_BATCHING); nb::enum_(m, "ContextChunkingPolicy") .value("EQUAL_PROGRESS", tle::ContextChunkingPolicy::kEQUAL_PROGRESS) diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp index e3d9d6c1c6a..8755e3cbc10 100644 --- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp @@ -90,7 +90,8 @@ void initBindings(pybind11::module_& m) py::enum_(m, "CapacitySchedulerPolicy") .value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION) .value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT) - .value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH); + .value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH) + .value("NON_MIX_BATCHING", tle::CapacitySchedulerPolicy::kNON_MIX_BATCHING); py::enum_(m, "ContextChunkingPolicy") .value("EQUAL_PROGRESS", tle::ContextChunkingPolicy::kEQUAL_PROGRESS) diff --git a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp index ba611e8720a..570345aa5ef 100644 --- a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp @@ -17,6 +17,7 @@ #include #include +#include #include "tensorrt_llm/batch_manager/capacityScheduler.h" #include "tensorrt_llm/batch_manager/common.h" @@ -1885,3 +1886,149 @@ TEST_F(CapacitySchedulerTest, SimpleFitsStaticBatch) EXPECT_EQ(numIterations, 160); } } + +TEST_F(CapacitySchedulerTest, NonMixBatchingSchedulerNoMixing) +{ + SizeType32 kvCacheMaxNumTokens = 2000; // Increased for more capacity + SizeType32 kvCacheTokensPerBlock = 10; + SizeType32 kvCacheMaxNumTokensPerSeq = 200; // Increased per sequence + SizeType32 maxNumRequests = 12; // Increased to handle more requests + SizeType32 maxInputLen = 1000; + + auto kvCacheManager + = getKvCacheManager(maxNumRequests, kvCacheTokensPerBlock, kvCacheMaxNumTokens, kvCacheMaxNumTokensPerSeq, 0); + auto peftCacheManager = getPeftCacheManager(); + CapacitySchedulerPolicy capacitySchedulerPolicy = CapacitySchedulerPolicy::kNON_MIX_BATCHING; + auto capacityScheduler = CapacityScheduler(maxNumRequests, capacitySchedulerPolicy, kvCacheManager != nullptr); + + int32_t maxNewTokens = 5; // Reduced to slow down completion + int32_t promptLen = 10; + + // Create a larger pool of requests - ALL start as context requests + RequestList activeRequests; + RequestList pendingRequests; // Requests to add later + + // Add initial 4 context requests + for (int i = 0; i < 4; ++i) + { + activeRequests.push_back(createRequest(promptLen, maxNewTokens, i, 
std::nullopt, + tensorrt_llm::executor::Request::kDefaultPriority, LlmRequestState::kCONTEXT_INIT)); + } + + // Prepare additional context requests to add during iterations + for (int i = 4; i < 10; ++i) + { + pendingRequests.push_back(createRequest(promptLen, maxNewTokens, i, std::nullopt, + tensorrt_llm::executor::Request::kDefaultPriority, LlmRequestState::kCONTEXT_INIT)); + } + + // Test multiple scheduling iterations to verify alternating behavior + std::vector schedulingPattern; + int generationStepsRemaining = 0; // Track how many generation steps each request needs + std::map requestGenerationSteps; // Track generation steps per request + + for (int iteration = 0; iteration < 12 && (!activeRequests.empty() || !pendingRequests.empty()); ++iteration) + { + // Add new requests periodically to maintain diversity + if (iteration % 3 == 0 && !pendingRequests.empty()) + { + activeRequests.push_back(pendingRequests.front()); + pendingRequests.pop_front(); + } + + // Skip if no active requests + if (activeRequests.empty()) + { + continue; + } + + auto [scheduledRequests, scheduledDisaggGenInitRequests, pausedRequests] + = capacityScheduler(activeRequests, kvCacheManager, peftCacheManager, std::nullopt); + + EXPECT_FALSE(scheduledRequests.empty()) << "Should schedule some requests in iteration " << iteration; + + // Verify batch purity: all requests in the same batch should have the same state type + bool hasContextRequests = false; + bool hasGenerationRequests = false; + + for (auto const& req : scheduledRequests) + { + if (req->isContextInitState() || req->isDisaggGenerationInitState()) + { + hasContextRequests = true; + } + else if (req->isGenerationInProgressState()) + { + hasGenerationRequests = true; + } + } + + // The key test: ensure no mixing in the same batch + EXPECT_FALSE(hasContextRequests && hasGenerationRequests) + << "Batch mixing detected in iteration " << iteration + << " - found both context and generation requests in the same batch"; + + // Record the scheduling pattern for alternation verification + if (hasContextRequests) + { + schedulingPattern.push_back("CONTEXT"); + } + else if (hasGenerationRequests) + { + schedulingPattern.push_back("GENERATION"); + } + + // Simulate request progression - slower completion + for (auto& req : scheduledRequests) + { + if (req->isContextInitState()) + { + // Transition context to generation + req->setState(LlmRequestState::kGENERATION_IN_PROGRESS); + req->setContextCurrentPosition(promptLen); + req->setDecodingIter(1); + // Initialize generation steps needed + requestGenerationSteps[req->mRequestId] = maxNewTokens; + } + else if (req->isGenerationInProgressState()) + { + // Progress generation request by one step + auto reqId = req->mRequestId; + if (requestGenerationSteps.find(reqId) == requestGenerationSteps.end()) + { + requestGenerationSteps[reqId] = maxNewTokens; + } + + requestGenerationSteps[reqId]--; + req->setDecodingIter(req->getDecodingIter() + 1); + + // Complete only if all steps are done + if (requestGenerationSteps[reqId] <= 0) + { + req->setState(LlmRequestState::kGENERATION_COMPLETE); + } + } + } + + // Remove only completed requests + activeRequests.erase( + std::remove_if(activeRequests.begin(), activeRequests.end(), + [](auto const& req) { return req->getState() == LlmRequestState::kGENERATION_COMPLETE; }), + activeRequests.end()); + } + + // Verify scheduling pattern shows correct non-mixing behavior + EXPECT_FALSE(schedulingPattern.empty()) << "Should have recorded some scheduling pattern"; + + // Verify first 
batch should be CONTEXT (context priority) + EXPECT_EQ(schedulingPattern[0], "CONTEXT") << "First batch should be context requests due to context priority"; + + // Verify we have both types in the pattern (comprehensive test) + bool hasContext + = std::find(schedulingPattern.begin(), schedulingPattern.end(), "CONTEXT") != schedulingPattern.end(); + bool hasGeneration + = std::find(schedulingPattern.begin(), schedulingPattern.end(), "GENERATION") != schedulingPattern.end(); + + EXPECT_TRUE(hasContext) << "Scheduling pattern should include CONTEXT batches"; + EXPECT_TRUE(hasGeneration) << "Scheduling pattern should include GENERATION batches"; +} diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index dae057b61ea..6610e1090f7 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -11,12 +11,10 @@ from tensorrt_llm._utils import (confidential_compute_enabled, str_dtype_to_binding, torch_dtype_to_str) from tensorrt_llm.bindings.executor import DecodingMode -from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig, - EagleDecodingConfig, KvCacheConfig, - MTPDecodingConfig, PeftCacheConfig, - SamplerType, SchedulerConfig, - SparseAttentionConfig, - SpeculativeConfig, TorchLlmArgs) +from tensorrt_llm.llmapi.llm_args import ( + CacheTransceiverConfig, CapacitySchedulerPolicy, EagleDecodingConfig, + KvCacheConfig, MTPDecodingConfig, PeftCacheConfig, SamplerType, + SchedulerConfig, SparseAttentionConfig, SpeculativeConfig, TorchLlmArgs) from tensorrt_llm.logger import logger from tensorrt_llm.lora_helper import (LoraConfig, get_default_trtllm_modules_to_hf_modules) @@ -852,6 +850,11 @@ def create_py_executor_instance( if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager: scheduler_capacity += 1 + # Set the capacity scheduler policy to NON_MIX_BATCHING for CI testing + scheduler_config.capacity_scheduler_policy = CapacitySchedulerPolicy.NON_MIX_BATCHING + logger.info( + f"scheduler_config.capacity_scheduler_policy: {scheduler_config.capacity_scheduler_policy}" + ) capacity_scheduler = BindCapacityScheduler( scheduler_capacity, kv_cache_manager.impl if kv_cache_manager is not None else None, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index f186da6cd89..f535e42982b 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1893,9 +1893,18 @@ def _prepare_tp_inputs( cache_indirection_buffer: Optional[torch.Tensor] = None, num_accepted_tokens_device: Optional[torch.Tensor] = None, req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None, - resource_manager: Optional[ResourceManager] = None): + resource_manager: Optional[ResourceManager] = None, + previous_batch_seq_slots: Optional[set] = None): """ Prepare inputs for Pytorch Model. + + Args: + previous_batch_seq_slots: Set of seq_slots that were in the previous batch. + Used to adjust position_id for generation requests when overlap scheduler + is enabled with NON_MIX_BATCHING. If a generation request was NOT in the + previous batch (e.g., the previous batch was context-only), the token in + previous_tensors_device is from an earlier batch, but max_beam_num_tokens + was already updated. We need to subtract 1 to match the token position. 
""" new_tokens_device, new_tokens_lens_device, next_draft_tokens_device = None, None, None @@ -2249,7 +2258,17 @@ def _prepare_tp_inputs( first_beam = 0 if beam == first_beam: previous_batch_indices.append(request.py_batch_idx) - past_seen_token_num = request.max_beam_num_tokens + # For NON_MIX_BATCHING with overlap scheduler: if the request was NOT + # in the previous batch, the token in previous_tensors_device is from + # an earlier batch. However, max_beam_num_tokens reflects the update + # from that earlier batch's _update_requests (which incremented it). + # We need to use max_beam_num_tokens - 1 to match the token's position. + if (previous_batch_seq_slots is not None + and request.py_seq_slot + not in previous_batch_seq_slots): + past_seen_token_num = request.max_beam_num_tokens - 1 + else: + past_seen_token_num = request.max_beam_num_tokens position_id = past_seen_token_num if self.mapping.has_cp_helix(): @@ -3202,7 +3221,8 @@ def _prepare_inputs( cache_indirection_buffer: Optional[torch.Tensor] = None, num_accepted_tokens_device: Optional[torch.Tensor] = None, req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None, - resource_manager: Optional[ResourceManager] = None): + resource_manager: Optional[ResourceManager] = None, + previous_batch_seq_slots: Optional[set] = None): if self.mapping is not None and 'cp_type' in self.mapping.cp_config: cp_type = self.mapping.cp_config['cp_type'] if CpType.STAR == cp_type: @@ -3215,12 +3235,11 @@ def _prepare_inputs( raise NotImplementedError( f"Unsupported cp_type {getattr(cp_type, 'name', cp_type)}.") - return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager, - attn_metadata, spec_metadata, - new_tensors_device, - cache_indirection_buffer, - num_accepted_tokens_device, - req_id_to_old_request, resource_manager) + return self._prepare_tp_inputs( + scheduled_requests, kv_cache_manager, attn_metadata, spec_metadata, + new_tensors_device, cache_indirection_buffer, + num_accepted_tokens_device, req_id_to_old_request, resource_manager, + previous_batch_seq_slots) @torch.inference_mode() @with_model_extra_attrs(lambda self: self.model.extra_attrs) @@ -3232,7 +3251,8 @@ def forward(self, cache_indirection_buffer: Optional[torch.Tensor] = None, spec_decoding_tensor: Optional[SpecDecodingTensor] = None, num_accepted_tokens_device: Optional[torch.Tensor] = None, - req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None): + req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None, + previous_batch_seq_slots: Optional[set] = None): kv_cache_manager = resource_manager.get_resource_manager( self.kv_cache_manager_key) @@ -3309,7 +3329,7 @@ def forward(self, padded_requests, kv_cache_manager, attn_metadata, spec_metadata, new_tensors_device, cache_indirection_buffer, num_accepted_tokens_device, req_id_to_old_request, - resource_manager) + resource_manager, previous_batch_seq_slots) with with_shared_pool(self.cuda_graph_runner.get_graph_pool()): if not can_run_graph: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 8f54aca6c48..03efa8b851c 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -31,7 +31,8 @@ StaticBatchingStats) from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType, ReqIdsSet) -from tensorrt_llm.llmapi.llm_args import PeftCacheConfig +from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy, + PeftCacheConfig) from tensorrt_llm.logger import logger from 
tensorrt_llm.mapping import CpType from tensorrt_llm.runtime.generation import CUASSERT @@ -1655,9 +1656,24 @@ def _executor_loop_overlap(self): else: previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device + # Extract seq_slots from the previous batch to detect if generation requests + # were in the previous batch. This is important for NON_MIX_BATCHING: + # if a generation request wasn't in the previous batch (e.g., it was context-only), + # the position_id calculation needs adjustment. + previous_batch_seq_slots = None + if self.scheduler_config.capacity_scheduler_policy == CapacitySchedulerPolicy.NON_MIX_BATCHING and self.previous_batch and self.previous_batch.sample_state: + prev_scheduled = self.previous_batch.sample_state.scheduled_requests + previous_batch_seq_slots = set() + for req in prev_scheduled.context_requests: + if req.py_seq_slot is not None: + previous_batch_seq_slots.add(req.py_seq_slot) + for req in prev_scheduled.generation_requests: + if req.py_seq_slot is not None: + previous_batch_seq_slots.add(req.py_seq_slot) + batch_outputs = self._forward_step( scheduled_batch, previous_tensors_device, - num_accepted_tokens_device) + num_accepted_tokens_device, previous_batch_seq_slots) if self.previous_batch is not None: self._update_requests(self.previous_batch.sample_state) @@ -2225,11 +2241,11 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0): self.kv_cache_transceiver.check_gen_transfer_status(atLeastNum) self._check_cache_transfer_errors("generation requests") - def _forward_step( - self, - scheduled_requests, - new_tensors_device: Optional[SampleStateTensors] = None, - num_accepted_tokens_device: Optional[torch.Tensor] = None): + def _forward_step(self, + scheduled_requests, + new_tensors_device: Optional[SampleStateTensors] = None, + num_accepted_tokens_device: Optional[torch.Tensor] = None, + previous_batch_seq_slots: Optional[set] = None): ExpertStatistic.set_iter(self.iter_counter) @nvtx_range( @@ -2237,14 +2253,15 @@ def _forward_step( ) def forward(scheduled_requests, resource_manager, new_tensors_device, gather_context_logits, cache_indirection_buffer, - num_accepted_tokens_device): + num_accepted_tokens_device, previous_batch_seq_slots): return self.model_engine.forward( scheduled_requests, resource_manager, new_tensors_device, gather_context_logits=gather_context_logits, cache_indirection_buffer=cache_indirection_buffer, - num_accepted_tokens_device=num_accepted_tokens_device) + num_accepted_tokens_device=num_accepted_tokens_device, + previous_batch_seq_slots=previous_batch_seq_slots) try: gather_context_logits = any( @@ -2259,7 +2276,8 @@ def forward(scheduled_requests, resource_manager, new_tensors_device, outputs = forward(scheduled_requests, self.resource_manager, new_tensors_device, gather_context_logits, cache_indirection_buffer, - num_accepted_tokens_device) + num_accepted_tokens_device, + previous_batch_seq_slots) # Ensure the default stream waits for execution_stream to complete # before downstream operations use the outputs. 
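# A minimal, standalone Python sketch of the position-id adjustment described in the
# model_engine.py docstring and the py_executor.py comment above, for NON_MIX_BATCHING
# combined with the overlap scheduler. GenRequest and its fields are illustrative
# stand-ins for the real LlmRequest API, not code from this patch.
from dataclasses import dataclass
from typing import Optional, Set


@dataclass
class GenRequest:
    py_seq_slot: int
    max_beam_num_tokens: int  # already advanced by the earlier batch's _update_requests


def past_seen_token_num(req: GenRequest,
                        previous_batch_seq_slots: Optional[Set[int]]) -> int:
    """Return the position of the token carried over in previous_tensors_device.

    If the request was not part of the previous (context-only) batch, that token
    comes from an earlier batch while max_beam_num_tokens was already bumped for it,
    so subtract one to line position_id up with the token actually being fed in.
    """
    if (previous_batch_seq_slots is not None
            and req.py_seq_slot not in previous_batch_seq_slots):
        return req.max_beam_num_tokens - 1
    return req.max_beam_num_tokens


if __name__ == "__main__":
    req = GenRequest(py_seq_slot=3, max_beam_num_tokens=12)
    # Previous batch was context-only (slots 5..8): fall back by one position.
    assert past_seen_token_num(req, {5, 6, 7, 8}) == 11
    # The request was in the previous batch: use max_beam_num_tokens as-is.
    assert past_seen_token_num(req, {3, 5}) == 12
    # Adjustment disabled (policy is not NON_MIX_BATCHING): no slot set is passed.
    assert past_seen_token_num(req, None) == 12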
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 2c1d8f916f5..7e463c23267 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -148,6 +148,7 @@ def __init__( super(BindCapacityScheduler, self).__init__() self.kv_cache_manager = kv_cache_manager self.peft_cache_manager = peft_cache_manager + self.scheduler_policy = scheduler_policy self.impl = tb_internal.algorithms.CapacityScheduler( max_num_requests=max_num_requests, @@ -275,6 +276,7 @@ def schedule_request(self, active_requests: RequestList, context_requests, generation_requests = self.micro_batch_scheduler.schedule( fitting_requests, inflight_request_ids) + # Convert from binding type RequestVector to list[LlmRequest], # so Python fields on LlmRequest won't be stripped away return SchedulerOutput(list(context_requests), diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 3f15252b84f..ab7001a5a90 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1391,6 +1391,7 @@ class CapacitySchedulerPolicy(StrEnum, metaclass=PybindMirrorEnumMeta): MAX_UTILIZATION = "MAX_UTILIZATION" GUARANTEED_NO_EVICT = "GUARANTEED_NO_EVICT" STATIC_BATCH = "STATIC_BATCH" + NON_MIX_BATCHING = "NON_MIX_BATCHING" def _to_pybind(self): return getattr(_CapacitySchedulerPolicy, self.value) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index ee416eb247e..38058324f72 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1342,6 +1342,39 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): MODEL_NAME = "deepseek-ai/DeepSeek-V3-Lite" MODEL_PATH = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16" + @parametrize_with_ids("max_batch_size", + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]) + @parametrize_with_ids("max_num_tokens", + [1024, 2048, 4096, 8192, 16384, 32768]) + def test_cuda_graph_update(self, max_batch_size, max_num_tokens): + kv_cache_config = KvCacheConfig(dtype="fp8", + enable_block_reuse=False, + free_gpu_memory_fraction=0.75) + cuda_graph_config = CudaGraphConfig(enable_padding=True) + from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy, + SchedulerConfig) + pytorch_config = dict( + attn_backend="TRTLLM", + moe_config=MoeConfig(backend="DEEPGEMM"), + print_iter_log=True, + enable_attention_dp=False, + max_batch_size=max_batch_size, + disable_overlap_scheduler=False, + scheduler_config=SchedulerConfig( + capacity_scheduler_policy=CapacitySchedulerPolicy. + NON_MIX_BATCHING, ), + ) + + with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8", + kv_cache_config=kv_cache_config, + cuda_graph_config=cuda_graph_config, + **pytorch_config) as llm: + + assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES + + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.skip_less_device_memory(60000) # Chunked Prefill for MLA can only be enabled on SM100 @parametrize_with_ids("enable_chunked_prefill", [False, True])
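# A minimal, standalone Python sketch of the batch-type selection implemented by the
# C++ NonMixBatchingScheduler above (categorizeRequests + determineSchedulingStrategy):
# batches are never mixed, context/disagg-init requests win the first cycle, and the
# scheduler alternates whenever both kinds are waiting. The Request class is an
# illustrative stand-in, and KV cache / PEFT resource checks are omitted.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Request:
    request_id: int
    is_context: bool  # context-init or disagg-generation-init


class NonMixBatchPicker:
    def __init__(self) -> None:
        # Start with False so that context requests are preferred on the first cycle.
        self.last_was_context = False

    def pick(self, active: List[Request]) -> Tuple[str, List[Request]]:
        context = [r for r in active if r.is_context]
        generation = [r for r in active if not r.is_context]
        if not context and not generation:
            return "none", []
        if not context:
            schedule_context = False
        elif not generation:
            schedule_context = True
        else:
            # Both kinds available: simply alternate with the previous batch type.
            schedule_context = not self.last_was_context
        self.last_was_context = schedule_context
        batch = context if schedule_context else generation
        return ("context" if schedule_context else "generation"), batch


if __name__ == "__main__":
    picker = NonMixBatchPicker()
    active = [Request(0, True), Request(1, True), Request(2, False), Request(3, False)]
    kinds = [picker.pick(active)[0] for _ in range(4)]
    # Context first, then strict alternation while both kinds are pending.
    assert kinds == ["context", "generation", "context", "generation"]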