Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ namespace tensorrt_llm::batch_manager
namespace kv_cache_manager
{
class BaseKVCacheManager;
}
struct BlockKey;
struct BlockKeyHasher;
} // namespace kv_cache_manager
class BasePeftCacheManager;
} // namespace tensorrt_llm::batch_manager

Expand Down Expand Up @@ -60,6 +62,29 @@ class BaseCapacityScheduler
return mNoScheduleAfterState;
}

/// @brief Check if a request should be skipped from scheduling
/// @param req The request to check
/// @param allowDisaggGeneration Allow disagg_generation_init requests to be scheduled
/// @return true if the request should be skipped, false otherwise
[[nodiscard]] bool shouldSkipRequest(
std::shared_ptr<LlmRequest> const& req, bool allowDisaggGeneration = true) const
{
// Check basic scheduling conditions first
bool basicSkipCondition
= !req->hasReachedState(getNoScheduleUntilState()) || req->hasReachedState(getNoScheduleAfterState());

if (allowDisaggGeneration)
{
// Allow disagg_generation_init requests to be scheduled, so they're excluded from the disagg check
return !req->isDisaggGenerationInitState() && basicSkipCondition;
}
else
{
// Don't make exception for disagg_generation_init requests
return basicSkipCondition;
}
}

private:
/// The state until/after which the scheduler should not schedule requests
LlmRequestState mNoScheduleUntilState;
Expand Down Expand Up @@ -140,6 +165,49 @@ class StaticBatchScheduler : public GuaranteedNoEvictScheduler
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
};

/// @brief Capacity scheduler that never mixes context and generation requests in one batch.
/// @details Each scheduled batch contains only context requests OR only generation requests.
///          Context requests have priority; the scheduler alternates between the two phases
///          using mLastWasContext to record which phase was scheduled last.
class NonMixBatchingScheduler : public BaseCapacityScheduler
{
public:
    /// @param maxNumRequests Upper bound on the number of requests scheduled at once.
    /// @param noScheduleUntilState Requests are not scheduled before reaching this state.
    /// @param noScheduleAfterState Requests are not scheduled once they have reached this state.
    NonMixBatchingScheduler(SizeType32 maxNumRequests,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    /// @brief Select the requests to run next from activeRequests, respecting KV-cache and
    ///        (optional) cross-KV-cache / PEFT-cache capacity.
    /// @return A pair of request vectors; presumably (scheduled, paused/evicted) as in the other
    ///         schedulers of this header -- TODO(review): confirm against the .cpp implementation.
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
        kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
        OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
        OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;

private:
    SizeType32 mMaxNumRequests;
    /// @brief Records whether the last scheduled batch was context phase (true) or generation phase (false).
    /// Initialized to false so that context requests are prioritized in the first scheduling cycle.
    /// Declared mutable because operator() is const yet must persist the alternation state between
    /// calls. NOTE(review): this makes const invocations stateful and not thread-safe -- confirm the
    /// scheduler is only called from a single thread.
    mutable bool mLastWasContext{false};

    /// @brief Partition active requests into context-phase and generation-phase vectors.
    [[nodiscard]] std::pair<RequestVector, RequestVector> categorizeRequests(RequestList const& activeRequests) const;

    /// @brief Decide which phase to schedule next, based on which request types are available
    ///        and the alternating policy tracked by mLastWasContext.
    /// @return Presumably true to schedule context requests, false for generation -- TODO(review):
    ///         confirm against the .cpp implementation.
    [[nodiscard]] bool determineSchedulingStrategy(
        RequestVector const& contextRequests, RequestVector const& generationRequests) const;

    /// @brief Schedule requests of the selected phase, claiming KV-cache and PEFT resources.
    /// @param newlyContributedContextBlocks / newlyContributedCrossContextBlocks Accumulate the
    ///        reusable KV blocks contributed by requests scheduled in this call, keyed by BlockKey.
    [[nodiscard]] RequestVector scheduleRequests(RequestVector const& requestsToSchedule, bool shouldScheduleContext,
        bool skippingIsRelevant, kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
        OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
        OptionalRef<BasePeftCacheManager const> peftCacheManager,
        std::unordered_set<kv_cache_manager::BlockKey, kv_cache_manager::BlockKeyHasher>& newlyContributedContextBlocks,
        std::unordered_set<kv_cache_manager::BlockKey, kv_cache_manager::BlockKeyHasher>&
            newlyContributedCrossContextBlocks) const;

    /// @brief Check whether the PEFT cache can hold the request's task pages.
    /// @param claimedPeftPages Running count of pages claimed so far; updated on success.
    /// @param uniqTaskIds Task ids already counted, so one task's pages are only claimed once.
    [[nodiscard]] bool checkPeftResources(std::shared_ptr<LlmRequest> const& req,
        OptionalRef<BasePeftCacheManager const> peftCacheManager, SizeType32 maxPeftCachePages,
        SizeType32& claimedPeftPages, std::unordered_set<uint64_t>& uniqTaskIds) const;
};

class CapacityScheduler : public Algorithm
{
public:
Expand Down Expand Up @@ -169,7 +237,7 @@ class CapacityScheduler : public Algorithm

private:
std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler,
StaticBatchScheduler>
StaticBatchScheduler, NonMixBatchingScheduler>
mScheduler;
};

Expand Down
6 changes: 5 additions & 1 deletion cpp/include/tensorrt_llm/executor/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,11 @@ enum class CapacitySchedulerPolicy

/// @brief kSTATIC_BATCH does not schedule new requests until all requests in current batch are completed.
/// Similar to kGUARANTEED_NO_EVICT, requests will run to completion without eviction.
kSTATIC_BATCH = 2
kSTATIC_BATCH = 2,

/// @brief kNON_MIX_BATCHING ensures each batch contains only context OR only generation requests.
/// Prevents mixing of context and generation requests in the same batch, with context having priority.
kNON_MIX_BATCHING = 3
};

std::ostream& operator<<(std::ostream& os, CapacitySchedulerPolicy policy);
Expand Down
Loading
Loading