From 64bce0f181b8b9db69a85f4f0cb751c9d0ca6576 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 13:37:59 +0800 Subject: [PATCH 01/25] re-implement micro batch scheduler and capacity scheduler in python Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 551 ++++++++++++++++++-- 1 file changed, 496 insertions(+), 55 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 2c1d8f916f5..ae8dc923fb2 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -1,7 +1,9 @@ +import dataclasses from abc import ABC, abstractmethod from collections import namedtuple from dataclasses import dataclass -from typing import Optional, Tuple +from enum import Enum +from typing import Optional, Set, Tuple from strenum import StrEnum @@ -164,60 +166,6 @@ def schedule_request( self.peft_cache_manager) -class GuaranteedNoEvictScheduler(CapacityScheduler): - # only schedule requests has no_schedule_until_state <= state < no_schedule_after_state - no_schedule_until_state = LlmRequestState.CONTEXT_INIT - no_schedule_after_state = LlmRequestState.GENERATION_COMPLETE - - def __init__(self, max_num_requests: int, kv_cache_manager): - super(GuaranteedNoEvictScheduler, self).__init__() - self.max_num_requests = max_num_requests - self.kv_cache_manager = kv_cache_manager - - def schedule_request( - self, active_requests: RequestList - ) -> tuple[list[LlmRequest], list[LlmRequest]]: - scheduled_requests = [] - pending_requests = [] - reserved_blocks = 0 - max_blocks = self.kv_cache_manager.get_max_resource_count() - for request in active_requests: - req_state = request.state - # if request cannot be scheduled yet or request should no longer be scheduled, skip - if req_state.value < self.no_schedule_until_state.value or req_state.value >= self.no_schedule_after_state.value: - continue - - if len(scheduled_requests - ) >= self.max_num_requests or reserved_blocks >= max_blocks: - break - elif req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE: - scheduled_requests.append(request) - reserved_blocks += self.kv_cache_manager.get_needed_resource_to_completion( - request) - else: - pending_requests.append(request) - - avaiable_blocks = max_blocks - reserved_blocks - for request in pending_requests: - req_state = request.state - if len(scheduled_requests) >= self.max_num_requests: - break - elif req_state == LlmRequestState.CONTEXT_INIT: - needed_blocks = self.kv_cache_manager.get_needed_resource_to_completion( - request) - if needed_blocks <= avaiable_blocks: - scheduled_requests.append(request) - avaiable_blocks -= needed_blocks - elif needed_blocks > avaiable_blocks: - # If one requests fails to be scheduled, break - break - - assert len(scheduled_requests) > 0, ( - "no pending request can get enough resource to complete, " - "please increase KV cache pool size.") - return scheduled_requests, [] - - class MicroBatchScheduler(ABC): @abstractmethod @@ -286,3 +234,496 @@ def can_schedule(self, requests: RequestList) -> bool: fitting_requests, _, _ = self.capacity_scheduler.schedule_request( requests) return len(fitting_requests) == len(requests) + + +class ChunkingPolicy(Enum): + EQUAL_PROGRESS = 1 + FIRST_COME_FIRST_SERVED = 2 + + +@dataclasses.dataclass +class ContextChunkingConfig: + chunking_policy: ChunkingPolicy + chunk_unit_size: int + + 
+class PyMicroBatchScheduler(MicroBatchScheduler): + + def __init__( + self, + max_batch_size: int, + max_num_tokens: Optional[int] = None, + ctx_chunk_config: Optional[ContextChunkingConfig] = None, + ): + super().__init__() + self.max_batch_size = max_batch_size + self.max_num_tokens = max_num_tokens + self.ctx_chunk_config = ctx_chunk_config + + def schedule( + self, active_requests: RequestList, + inflight_request_ids: Set[int]) -> Tuple[RequestList, RequestList]: + + context_requests: RequestList = [] + generation_requests: RequestList = [] + + current_batch_tokens = 0 + scheduled_req_count = 0 + scheduled_beam_width = 0 + + contexts_to_be_chunked: RequestList = [] + num_chunked_tokens = 0 + all_context_fits = True + + # 1. First Pass: Filter & Categorize (Generation First) + for req in active_requests: + # Skip invalid states (Simplified check, assuming caller filters mostly) + if req.request_id in inflight_request_ids: + continue + + # --- Generation Handling --- + if req.state == LlmRequestState.GENERATION_IN_PROGRESS: + beam_width = req.sampling_config.beam_width + req_num_tokens = beam_width + req.num_draft_tokens + + # Check Global Token Budget + if self.max_num_tokens is not None and (current_batch_tokens + + req_num_tokens + > self.max_num_tokens): + break + + # Check Beam Width Consistency (Batch constraint) + if scheduled_beam_width == 0: + scheduled_beam_width = beam_width + elif scheduled_beam_width != beam_width: + continue + + generation_requests.append(req) + current_batch_tokens += req_num_tokens + + # --- Context Handling --- + elif req.state == LlmRequestState.CONTEXT_INIT: + if not self.ctx_chunk_config: + # No Chunking: Greedy allocation + req_num_tokens = req.get_context_remaining_length() + draft_tokens = req.num_draft_tokens if req.has_draft_tokens else 0 + total_tokens = req_num_tokens + draft_tokens + + if self.max_num_tokens is not None and ( + current_batch_tokens + total_tokens + > self.max_num_tokens): + break + + context_requests.append(req) + current_batch_tokens += total_tokens + else: + # Chunking Enabled: Defer calculation + remaining = req.get_context_remaining_length() + # Just an estimate for budget check + req.context_chunk_size = remaining + + draft_tokens = req.num_draft_tokens if ( + req.is_last_context_chunk + and req.has_draft_tokens) else 0 + req_num_tokens = remaining + draft_tokens + + contexts_to_be_chunked.append(req) + num_chunked_tokens += req_num_tokens + + # Batch Size Check + scheduled_req_count += 1 + if scheduled_req_count >= self.max_batch_size: + break + + # 2. Check if chunking logic is needed + if self.max_num_tokens is not None and num_chunked_tokens > ( + self.max_num_tokens - current_batch_tokens): + all_context_fits = False + + # 3. Apply Chunking Strategy + if not all_context_fits and contexts_to_be_chunked: + if not self.ctx_chunk_config: + # Should effectively be handled above, but as a fallback + pass + else: + remaining_capacity = ( + self.max_num_tokens - current_batch_tokens + ) if self.max_num_tokens is not None else None + self._set_ctx_requests_chunk_size(contexts_to_be_chunked, + remaining_capacity) + + # 4. 
Finalize Context Requests + for req in contexts_to_be_chunked: + if req.context_chunk_size > 0: + context_requests.append(req) + current_batch_tokens += req.context_chunk_size + + return context_requests, generation_requests + + def _set_ctx_requests_chunk_size(self, requests: RequestList, + capacity: Optional[int]): + # Reset + for req in requests: + req.context_chunk_size = 0 + + policy = self.ctx_chunk_config.chunking_policy + unit_size = self.ctx_chunk_config.chunk_unit_size + + if policy == ChunkingPolicy.EQUAL_PROGRESS: + self._chunk_equal_progress(requests, capacity, unit_size) + elif policy == ChunkingPolicy.FIRST_COME_FIRST_SERVED: + self._chunk_fcfs(requests, capacity, unit_size) + + # Optimization: Fit draft tokens if space allows + self._fit_draft_tokens(requests, capacity, unit_size) + + def _chunk_equal_progress(self, requests, capacity, unit_size): + num_ctx_tokens = 0 + made_progress = True + + while (capacity is None or num_ctx_tokens < capacity) and made_progress: + made_progress = False + for req in requests: + past_size = req.context_chunk_size + remaining = req.get_context_remaining_length() + + if past_size >= remaining: + continue + + suggested_size = past_size + unit_size + actual_size = min(suggested_size, remaining) + increment = actual_size - past_size + + if increment > 0: + if capacity is not None and (num_ctx_tokens + increment + > capacity): + # Cannot fit this increment, stop growing this request + req.context_chunk_size = past_size + continue + + req.context_chunk_size = actual_size + num_ctx_tokens += increment + made_progress = True + + def _chunk_fcfs(self, requests, capacity, unit_size): + current_capacity = capacity if capacity is not None else float('inf') + + for req in requests: + remaining = req.get_context_remaining_length() + actual_size = remaining + + if current_capacity < actual_size: + actual_size = current_capacity + + # Align if truncated + if actual_size < remaining: + actual_size = (int(actual_size) // unit_size) * unit_size + + req.context_chunk_size = int(actual_size) + current_capacity -= req.context_chunk_size + + if current_capacity <= 0: + break + + def _fit_draft_tokens(self, requests, capacity, unit_size): + # Python port of fitDraftTokens + # Logic: If it is the last chunk, try to fit draft tokens without using a new KV block + current_tokens = sum(r.context_chunk_size for r in requests) + + for req in requests: + if req.is_last_context_chunk and req.has_draft_tokens: + chunk_size = req.context_chunk_size + remainder = chunk_size % unit_size + # Space left in the last block + space_in_block = 0 if remainder == 0 else (unit_size - + remainder) + + # Check constraints + allowed_space = space_in_block + if capacity is not None: + allowed_space = min(allowed_space, + capacity - current_tokens) + + # If we can't fit all draft tokens in the existing block/capacity, discard them + draft_needed = req.num_draft_tokens + if draft_needed > allowed_space: + # In python we might need a method to discard/update draft tokens on req + # req.discard_draft_tokens(draft_needed - allowed_space) + pass + else: + current_tokens += draft_needed + + +class PyCapacityScheduler(CapacityScheduler): + + def __init__( + self, + max_num_requests: int, + kv_cache_manager, + scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. 
+ MAX_UTILIZATION, + no_schedule_until_state=LlmRequestState.CONTEXT_INIT, + no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE, + ): + super().__init__() + self.max_num_requests = max_num_requests + self.kv_cache_manager = kv_cache_manager + self.policy = scheduler_policy + self.no_schedule_until_state = no_schedule_until_state + self.no_schedule_after_state = no_schedule_after_state + + def schedule_request( + self, active_requests: RequestList + ) -> Tuple[RequestList, RequestList, RequestList]: + + if self.policy == CapacitySchedulerPolicy.MAX_UTILIZATION: + return self._schedule_max_utilization(active_requests) + elif self.policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: + # Reuse existing implementation logic or simple pass-through + return self._schedule_guaranteed_no_evict(active_requests) + else: + raise NotImplementedError( + f"Policy {self.policy} not implemented in PyCapacityScheduler") + + def _schedule_max_utilization(self, active_requests: RequestList): + scheduled_requests = [] + paused_requests = [] + + # We need to simulate the C++ "BlockManager" state + # Since Phase 1 uses C++ Manager, we assume we call it. + # But C++ `startScheduling()` resets internal temp state. + if hasattr(self.kv_cache_manager, "start_scheduling"): + self.kv_cache_manager.start_scheduling() + + # Iterate through all requests + # Logic: Try to schedule. If fail, see if we can pause a running request to make room. + + iter(active_requests) + cached_active_list = list(active_requests) # For reverse lookups + + idx = 0 + while idx < len(cached_active_list): + req = cached_active_list[idx] + + # 1. State Filter + if (req.state.value < self.no_schedule_until_state.value + or req.state.value >= self.no_schedule_after_state.value): + # Cannot schedule, but keep iterating + idx += 1 + continue + + # 2. Max Requests Limit + if len(scheduled_requests) >= self.max_num_requests: + break + + # 3. KV Cache Check (The Critical Part) + # We assume KV Manager has a `check_allocation` or `prepare_blocks` method + # that returns needed blocks or None if it doesn't fit. + # In C++ this is `blocksManager.prepareNewNumberOfBlocks...` + + # NOTE: For Phase 1, we might need to expose a helper in C++ binding + # if direct block math isn't exposed. + # Assuming `kv_cache_manager.check_and_update_allocation(req)` returns True/False + + can_allocate = False + try: + # This function implies C++ side logic: "If I add this req, do I have blocks?" + # It updates the internal transaction state of C++ manager. + can_allocate = self.kv_cache_manager.check_and_update_allocation( + req) + except AttributeError: + # Fallback / Mock for logic understanding + can_allocate = True + + if can_allocate: + scheduled_requests.append(req) + idx += 1 + else: + # 4. Backtracking / Pausing Logic + # If we failed to allocate, can we pause a *previously scheduled* Running request? + # Find the last scheduled request that is in GENERATION phase + victim_idx = -1 + for i in range(len(scheduled_requests) - 1, -1, -1): + r = scheduled_requests[i] + if r.state == LlmRequestState.GENERATION_IN_PROGRESS: + victim_idx = i + break + + if victim_idx != -1: + # Found a victim to pause + victim_req = scheduled_requests.pop(victim_idx) + paused_requests.append(victim_req) + + # Revert allocation in C++ manager + if hasattr(self.kv_cache_manager, + "remove_sequence_from_scheduling"): + self.kv_cache_manager.remove_sequence_from_scheduling( + victim_req) + + # Do NOT increment idx. 
We retry the CURRENT request (req) + # because now there is more space. + continue + else: + # No victim found, and current request doesn't fit. + # Stop scheduling. + break + + # Filter Disagg Gen Init + fitting_requests = [] + fitting_disagg_gen_init = [] + for r in scheduled_requests: + if r.state == LlmRequestState.DISAGG_GENERATION_INIT: + fitting_disagg_gen_init.append(r) + else: + fitting_requests.append(r) + + return fitting_requests, fitting_disagg_gen_init, paused_requests + + def _schedule_guaranteed_no_evict(self, active_requests: RequestList): + scheduled_requests = [] + pending_requests = [] + + # 1. Simulate resource state + # We need to know the total number of blocks and how many are reserved by Running requests. + # Assuming KV Manager provides an interface to get the maximum resource count. + # If this is Pure Python Phase 1, we might need to call C++ binding or Shadow Manager. + max_blocks = self.kv_cache_manager.get_max_resource_count() + reserved_blocks = 0 + + # 2. First pass: Prioritize scheduling running requests (Running Requests) + # Core principle: No Eviction. As long as it is Generating, it must be retained. + for request in active_requests: + req_state = request.state + + # Filter out requests that cannot be scheduled yet + if (req_state.value < self.no_schedule_until_state.value + or req_state.value >= self.no_schedule_after_state.value): + continue + + # If the maximum number of requests is reached, or there isn't even enough memory for Running requests (extreme case), break. + # Note: GuaranteedNoEvict tries its best not to Evict, but if max_num_requests is full, there is no other way. + if len(scheduled_requests) >= self.max_num_requests: + pending_requests.append(request) + continue + + # Prioritize handling requests in the Generation phase + if (req_state == LlmRequestState.GENERATION_IN_PROGRESS + or req_state == LlmRequestState.GENERATION_TO_COMPLETE): + + # Calculate how many blocks are needed for this request to complete (Reserved to completion) + needed = self.kv_cache_manager.get_needed_resource_to_completion( + request) + + if reserved_blocks + needed > max_blocks: + # Extremely rare case: Memory fragmentation or overallocation causes Running requests to be unsustainable. + # At this point, we have to pause it (although the policy is named NoEvict, physical resource insufficiency is a hard constraint). + # But in standard implementation, we try to let it run. + pass + + scheduled_requests.append(request) + reserved_blocks += needed + else: + # Put Context requests into Pending queue first, try to schedule later + pending_requests.append(request) + + # 3. Second pass: Try to schedule new requests (Context Requests) + # Only after Running requests are satisfied, remaining resources are allocated to New Requests. + available_blocks = max_blocks - reserved_blocks + + for request in pending_requests: + if len(scheduled_requests) >= self.max_num_requests: + break + + # Handle Context Init or Disagg Gen Init + if (request.state == LlmRequestState.CONTEXT_INIT + or request.state == LlmRequestState.DISAGG_GENERATION_INIT): + + needed_blocks = self.kv_cache_manager.get_needed_resource_to_completion( + request) + + if needed_blocks <= available_blocks: + scheduled_requests.append(request) + available_blocks -= needed_blocks + else: + # Insufficient resources, cannot accept new requests. + # Because the policy is No Evict, we cannot pause Running requests to make room. 
+ # So once we encounter one that doesn't fit, subsequent ones usually won't fit either (unless filled by small requests). + # To maintain FIFO, we usually break here. + break + + # 4. Construct return values + # Under this policy, paused_requests are usually not actively generated (unless active_requests itself has paused ones and no resources to resume) + # Simplified handling here: unscheduled ones are considered paused/waiting. + + # Categorize according to interface requirements + fitting_requests = [] + fitting_disagg_gen_init = [] + paused_requests = [ + ] # Active Requests not selected (Running state squeezed out) + + # Identify which Active Running requests were squeezed out (theoretically shouldn't happen, but for completeness) + scheduled_ids = set(r.request_id for r in scheduled_requests) + for req in active_requests: + if req.request_id not in scheduled_ids and req.state == LlmRequestState.GENERATION_IN_PROGRESS: + paused_requests.append(req) + + for req in scheduled_requests: + if req.state == LlmRequestState.DISAGG_GENERATION_INIT: + fitting_disagg_gen_init.append(req) + else: + fitting_requests.append(req) + + return fitting_requests, fitting_disagg_gen_init, paused_requests + + +class SimpleSPMDScheduler(RequestScheduler): + + def __init__( + self, + max_batch_size: int, + max_num_tokens: int, + kv_cache_manager, + scheduler_policy: CapacitySchedulerPolicy, + ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None, + ): + # 1. Initialize Python Capacity Scheduler + self.capacity_scheduler = PyCapacityScheduler( + max_num_requests=max_batch_size, + kv_cache_manager=kv_cache_manager, + scheduler_policy=scheduler_policy) + + # 2. Initialize Python MicroBatch Scheduler + py_chunk_config = None + if ctx_chunk_config: + # Convert StrEnum to our Python Enum + policy_enum = ChunkingPolicy.EQUAL_PROGRESS if ctx_chunk_config[ + 0] == tb_internal.batch_manager.ChunkingPolicy.EQUAL_PROGRESS else ChunkingPolicy.FIRST_COME_FIRST_SERVED + py_chunk_config = ContextChunkingConfig(policy_enum, + ctx_chunk_config[1]) + + self.micro_batch_scheduler = PyMicroBatchScheduler( + max_batch_size=max_batch_size, + max_num_tokens=max_num_tokens, + ctx_chunk_config=py_chunk_config) + + def schedule_request(self, active_requests: RequestList, + inflight_request_ids: set[int]) -> SchedulerOutput: + # Step 1: Capacity Check (Who fits in memory?) + fitting_requests, fitting_disagg_gen_init, paused_requests = self.capacity_scheduler.schedule_request( + active_requests) + + # Step 2: MicroBatch Check (Who fits in token budget? 
+ Chunking) + context_requests, generation_requests = self.micro_batch_scheduler.schedule( + fitting_requests, inflight_request_ids) + + return SchedulerOutput( + context_requests=context_requests, + generation_requests=generation_requests, + paused_requests=paused_requests, + fitting_disagg_gen_init_requests=fitting_disagg_gen_init, + num_fitting_requests=len(fitting_requests)) + + def can_schedule(self, requests: RequestList) -> bool: + # Dry run capacity check + fitting, _, _ = self.capacity_scheduler.schedule_request(requests) + return len(fitting) == len(requests) From 034fffb0900632b2aefdec1ce9a5b5cbe63d22ba Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 13:53:31 +0800 Subject: [PATCH 02/25] refine Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 209 +++++++++++--------- 1 file changed, 118 insertions(+), 91 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index ae8dc923fb2..429da8d38cb 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -454,7 +454,14 @@ def _fit_draft_tokens(self, requests, capacity, unit_size): current_tokens += draft_needed -class PyCapacityScheduler(CapacityScheduler): +class PyCapacityScheduler: + """ + Python implementation of the C++ CapacityScheduler. + + It delegates the heavy lifting of block counting and tracking to the C++ + KVCacheManager via bindings, but controls the decision-making loop (policy) + in Python. + """ def __init__( self, @@ -465,7 +472,6 @@ def __init__( no_schedule_until_state=LlmRequestState.CONTEXT_INIT, no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE, ): - super().__init__() self.max_num_requests = max_num_requests self.kv_cache_manager = kv_cache_manager self.policy = scheduler_policy @@ -475,102 +481,127 @@ def __init__( def schedule_request( self, active_requests: RequestList ) -> Tuple[RequestList, RequestList, RequestList]: + """ + Main entry point. + Returns: (fitting_requests, fitting_disagg_gen_init, paused_requests) + """ if self.policy == CapacitySchedulerPolicy.MAX_UTILIZATION: return self._schedule_max_utilization(active_requests) elif self.policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: - # Reuse existing implementation logic or simple pass-through return self._schedule_guaranteed_no_evict(active_requests) else: raise NotImplementedError( f"Policy {self.policy} not implemented in PyCapacityScheduler") def _schedule_max_utilization(self, active_requests: RequestList): - scheduled_requests = [] - paused_requests = [] + """ + Greedy strategy with Backtracking. + 1. Try to schedule requests. + 2. If a request doesn't fit, try to 'pause' (evict) a running generation request + that was scheduled earlier in this loop to free up blocks. + 3. Retry the current request. + """ + scheduled_requests: RequestList = [] + paused_requests: RequestList = [] - # We need to simulate the C++ "BlockManager" state - # Since Phase 1 uses C++ Manager, we assume we call it. - # But C++ `startScheduling()` resets internal temp state. + # REQUIRED BINDING: C++ `startScheduling()` + # Resets the internal transactional state of block manager for this step. if hasattr(self.kv_cache_manager, "start_scheduling"): self.kv_cache_manager.start_scheduling() - # Iterate through all requests - # Logic: Try to schedule. If fail, see if we can pause a running request to make room. 
- - iter(active_requests) - cached_active_list = list(active_requests) # For reverse lookups - + # We use a while loop index because we might need to retry the *same* request + # after evicting a victim. + cached_active_list = list(active_requests) idx = 0 + while idx < len(cached_active_list): req = cached_active_list[idx] # 1. State Filter + # Skip requests that haven't reached INIT or are already DONE if (req.state.value < self.no_schedule_until_state.value or req.state.value >= self.no_schedule_after_state.value): - # Cannot schedule, but keep iterating idx += 1 continue - # 2. Max Requests Limit + # 2. Max Requests Check if len(scheduled_requests) >= self.max_num_requests: + # If we hit the request count limit, we can't schedule more. + # However, in C++ implementation, we might break or continue logic. + # Usually we break because sorting ensures priority. + # But here we simply treat unscheduled active requests as paused implicitly later. break - # 3. KV Cache Check (The Critical Part) - # We assume KV Manager has a `check_allocation` or `prepare_blocks` method - # that returns needed blocks or None if it doesn't fit. - # In C++ this is `blocksManager.prepareNewNumberOfBlocks...` - - # NOTE: For Phase 1, we might need to expose a helper in C++ binding - # if direct block math isn't exposed. - # Assuming `kv_cache_manager.check_and_update_allocation(req)` returns True/False - + # 3. Try Allocation (Atomic Check & Update) + # REQUIRED BINDING: `try_scheduling_request(req, max_requests_limit)` + # This binding should map to `trySchedulingRequestMaxUtilization` in C++. + # It performs the check: (available_blocks >= needed_blocks) + # If True, it commits the usage to the transaction and returns True. can_allocate = False try: - # This function implies C++ side logic: "If I add this req, do I have blocks?" - # It updates the internal transaction state of C++ manager. - can_allocate = self.kv_cache_manager.check_and_update_allocation( - req) + # Assuming binding takes (req, current_scheduled_count) or just (req) + # if manager doesn't track count. + can_allocate = self.kv_cache_manager.try_scheduling_request(req) except AttributeError: - # Fallback / Mock for logic understanding + # Fallback for development/mocking can_allocate = True if can_allocate: scheduled_requests.append(req) idx += 1 + continue + + # 4. Backtracking / Eviction Logic + # If we are here, 'req' did NOT fit. + # Can we pause a previously scheduled RUNNING request to make room? + + victim_idx = -1 + # Search backwards for a Generation request (we don't pause Context init usually) + for i in range(len(scheduled_requests) - 1, -1, -1): + r = scheduled_requests[i] + if r.state == LlmRequestState.GENERATION_IN_PROGRESS: + victim_idx = i + break + + if victim_idx != -1: + # Found a victim. Evict it. + victim_req = scheduled_requests.pop(victim_idx) + paused_requests.append(victim_req) + + # REQUIRED BINDING: `scheduling_remove_sequence(req_id)` + # Reverts the block usage of the victim in the current transaction. + if hasattr(self.kv_cache_manager, "scheduling_remove_sequence"): + self.kv_cache_manager.scheduling_remove_sequence( + victim_req.request_id) + + # CRITICAL: Do NOT increment `idx`. + # We loop back and try to schedule `req` again, now that space is freed. + continue else: - # 4. Backtracking / Pausing Logic - # If we failed to allocate, can we pause a *previously scheduled* Running request? 
- # Find the last scheduled request that is in GENERATION phase - victim_idx = -1 - for i in range(len(scheduled_requests) - 1, -1, -1): - r = scheduled_requests[i] - if r.state == LlmRequestState.GENERATION_IN_PROGRESS: - victim_idx = i - break + # No valid victim found, and current request doesn't fit. + # We cannot make progress. Stop scheduling. + break - if victim_idx != -1: - # Found a victim to pause - victim_req = scheduled_requests.pop(victim_idx) - paused_requests.append(victim_req) + # 5. Output Classification + # Any active request not in `scheduled_requests` is effectively paused/waiting. + # But `paused_requests` list contains specifically those we *actively* evicted. - # Revert allocation in C++ manager - if hasattr(self.kv_cache_manager, - "remove_sequence_from_scheduling"): - self.kv_cache_manager.remove_sequence_from_scheduling( - victim_req) + # We also need to capture requests that were active but we stopped loop before reaching them. + scheduled_ids = set(r.request_id for r in scheduled_requests) + evicted_ids = set(r.request_id for r in paused_requests) - # Do NOT increment idx. We retry the CURRENT request (req) - # because now there is more space. - continue - else: - # No victim found, and current request doesn't fit. - # Stop scheduling. - break + for req in active_requests: + if (req.state == LlmRequestState.GENERATION_IN_PROGRESS + and req.request_id not in scheduled_ids + and req.request_id not in evicted_ids): + # Request was running, but we ran out of slots/memory before processing it + # or we stopped scheduling. + paused_requests.append(req) - # Filter Disagg Gen Init fitting_requests = [] fitting_disagg_gen_init = [] + for r in scheduled_requests: if r.state == LlmRequestState.DISAGG_GENERATION_INIT: fitting_disagg_gen_init.append(r) @@ -580,61 +611,62 @@ def _schedule_max_utilization(self, active_requests: RequestList): return fitting_requests, fitting_disagg_gen_init, paused_requests def _schedule_guaranteed_no_evict(self, active_requests: RequestList): - scheduled_requests = [] - pending_requests = [] + """ + Conservative strategy. + 1. First, ensure ALL currently running requests have enough memory to run to COMPLETION. + If not, we technically shouldn't schedule them (or system is over-subscribed). + 2. Only then, use remaining memory for New (Context) requests. + """ + scheduled_requests: RequestList = [] + pending_requests: RequestList = [] - # 1. Simulate resource state - # We need to know the total number of blocks and how many are reserved by Running requests. - # Assuming KV Manager provides an interface to get the maximum resource count. - # If this is Pure Python Phase 1, we might need to call C++ binding or Shadow Manager. + # REQUIRED BINDING: `get_max_resource_count()` -> int max_blocks = self.kv_cache_manager.get_max_resource_count() reserved_blocks = 0 - # 2. First pass: Prioritize scheduling running requests (Running Requests) - # Core principle: No Eviction. As long as it is Generating, it must be retained. + # --- Pass 1: Running Requests (Priority) --- for request in active_requests: req_state = request.state - # Filter out requests that cannot be scheduled yet + # Filter valid states if (req_state.value < self.no_schedule_until_state.value or req_state.value >= self.no_schedule_after_state.value): continue - # If the maximum number of requests is reached, or there isn't even enough memory for Running requests (extreme case), break. 
- # Note: GuaranteedNoEvict tries its best not to Evict, but if max_num_requests is full, there is no other way. + # Hard constraints check if len(scheduled_requests) >= self.max_num_requests: pending_requests.append(request) continue - # Prioritize handling requests in the Generation phase + # Check Generation Requests if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): - # Calculate how many blocks are needed for this request to complete (Reserved to completion) + # REQUIRED BINDING: `get_needed_resource_to_completion(req)` -> int + # This calculates blocks needed for full generation length, not just next step. needed = self.kv_cache_manager.get_needed_resource_to_completion( request) if reserved_blocks + needed > max_blocks: - # Extremely rare case: Memory fragmentation or overallocation causes Running requests to be unsustainable. - # At this point, we have to pause it (although the policy is named NoEvict, physical resource insufficiency is a hard constraint). - # But in standard implementation, we try to let it run. - pass + # System Overload. Even though policy is NoEvict, we physically can't fit. + # We skip this request (effectively pausing it). + # NOTE: In strict NoEvict, this implies a system error or aggressive over-subscription. + pending_requests.append(request) + continue scheduled_requests.append(request) reserved_blocks += needed else: - # Put Context requests into Pending queue first, try to schedule later + # Defer Context/Init requests to Pass 2 pending_requests.append(request) - # 3. Second pass: Try to schedule new requests (Context Requests) - # Only after Running requests are satisfied, remaining resources are allocated to New Requests. + # --- Pass 2: New / Context Requests --- available_blocks = max_blocks - reserved_blocks for request in pending_requests: if len(scheduled_requests) >= self.max_num_requests: break - # Handle Context Init or Disagg Gen Init if (request.state == LlmRequestState.CONTEXT_INIT or request.state == LlmRequestState.DISAGG_GENERATION_INIT): @@ -645,26 +677,21 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests.append(request) available_blocks -= needed_blocks else: - # Insufficient resources, cannot accept new requests. - # Because the policy is No Evict, we cannot pause Running requests to make room. - # So once we encounter one that doesn't fit, subsequent ones usually won't fit either (unless filled by small requests). - # To maintain FIFO, we usually break here. + # Cannot fit new request. Since we cannot evict running requests, + # we stop here (Head-of-Line blocking behavior is typical for FIFO). break - # 4. Construct return values - # Under this policy, paused_requests are usually not actively generated (unless active_requests itself has paused ones and no resources to resume) - # Simplified handling here: unscheduled ones are considered paused/waiting. 
- - # Categorize according to interface requirements + # --- Output Classification --- fitting_requests = [] fitting_disagg_gen_init = [] - paused_requests = [ - ] # Active Requests not selected (Running state squeezed out) + paused_requests = [] - # Identify which Active Running requests were squeezed out (theoretically shouldn't happen, but for completeness) scheduled_ids = set(r.request_id for r in scheduled_requests) + + # Identify running requests that were implicitly paused for req in active_requests: - if req.request_id not in scheduled_ids and req.state == LlmRequestState.GENERATION_IN_PROGRESS: + if (req.request_id not in scheduled_ids + and req.state == LlmRequestState.GENERATION_IN_PROGRESS): paused_requests.append(req) for req in scheduled_requests: @@ -676,7 +703,7 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): return fitting_requests, fitting_disagg_gen_init, paused_requests -class SimpleSPMDScheduler(RequestScheduler): +class SimpleUnifiedScheduler(RequestScheduler): def __init__( self, From 927b417c7b63cbe97f1b21633b28a05a8954350d Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:01:26 +0800 Subject: [PATCH 03/25] enable SimpleUnifiedScheduler Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 35 ++++++++++++++++++------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 385d4d52a1a..9d1b7aa3e36 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -39,7 +39,7 @@ from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler, TRTLLMSampler) from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler, - SimpleScheduler) + SimpleScheduler, SimpleUnifiedScheduler) from .seq_slot_manager import SeqSlotManager GB = 1 << 30 @@ -806,15 +806,30 @@ def create_py_executor_instance( if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager: scheduler_capacity += 1 - capacity_scheduler = BindCapacityScheduler( - scheduler_capacity, - kv_cache_manager.impl if kv_cache_manager is not None else None, - peft_cache_manager.impl if peft_cache_manager is not None else None, - scheduler_config.capacity_scheduler_policy, - two_step_lookahead=mapping.has_pp()) - mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens, - ctx_chunk_config) - scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler) + use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1" + if use_python_scheduler: + if peft_cache_manager is not None: + logger.warning( + "PeftCacheManager is currently ignored by Python Scheduler (Phase 1)." 
+ ) + + scheduler = SimpleUnifiedScheduler( + max_batch_size=max_batch_size, + max_num_tokens=max_num_tokens, + kv_cache_manager=kv_cache_manager.impl + if kv_cache_manager is not None else None, + scheduler_policy=scheduler_config.capacity_scheduler_policy, + ctx_chunk_config=ctx_chunk_config) + else: + capacity_scheduler = BindCapacityScheduler( + scheduler_capacity, + kv_cache_manager.impl if kv_cache_manager is not None else None, + peft_cache_manager.impl if peft_cache_manager is not None else None, + scheduler_config.capacity_scheduler_policy, + two_step_lookahead=mapping.has_pp()) + mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens, + ctx_chunk_config) + scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler) config = model_engine.model.model_config.pretrained_config attention_type = AttentionTypeCpp.MLA if is_mla( From 3609b2013a8711c52dfd66a992ac20a37b4b351f Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:47:53 +0800 Subject: [PATCH 04/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 429da8d38cb..1f97ca20586 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -620,8 +620,7 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests: RequestList = [] pending_requests: RequestList = [] - # REQUIRED BINDING: `get_max_resource_count()` -> int - max_blocks = self.kv_cache_manager.get_max_resource_count() + max_blocks = self.kv_cache_manager.max_num_blocks reserved_blocks = 0 # --- Pass 1: Running Requests (Priority) --- @@ -642,9 +641,7 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): - # REQUIRED BINDING: `get_needed_resource_to_completion(req)` -> int - # This calculates blocks needed for full generation length, not just next step. 
- needed = self.kv_cache_manager.get_needed_resource_to_completion( + needed = self.kv_cache_manager.get_remaining_blocks_to_completion( request) if reserved_blocks + needed > max_blocks: @@ -670,7 +667,7 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): if (request.state == LlmRequestState.CONTEXT_INIT or request.state == LlmRequestState.DISAGG_GENERATION_INIT): - needed_blocks = self.kv_cache_manager.get_needed_resource_to_completion( + needed_blocks = self.kv_cache_manager.get_remaining_blocks_to_completion( request) if needed_blocks <= available_blocks: From c901b2198c421b314f2bfdaef7b289adfceb2b05 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:53:10 +0800 Subject: [PATCH 05/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 114 +++++++++----------- 1 file changed, 49 insertions(+), 65 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 1f97ca20586..1ec4af41376 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -495,23 +495,22 @@ def schedule_request( f"Policy {self.policy} not implemented in PyCapacityScheduler") def _schedule_max_utilization(self, active_requests: RequestList): - """ - Greedy strategy with Backtracking. - 1. Try to schedule requests. - 2. If a request doesn't fit, try to 'pause' (evict) a running generation request - that was scheduled earlier in this loop to free up blocks. - 3. Retry the current request. - """ scheduled_requests: RequestList = [] paused_requests: RequestList = [] - # REQUIRED BINDING: C++ `startScheduling()` - # Resets the internal transactional state of block manager for this step. + # Binding: Resets internal transactional state (e.g. mSchedulingNumFreeBlocks) if hasattr(self.kv_cache_manager, "start_scheduling"): self.kv_cache_manager.start_scheduling() - # We use a while loop index because we might need to retry the *same* request - # after evicting a victim. + # We track "scheduling" free blocks in Python to mimic C++ logic if binding missing + # C++ Manager tracks this internally via mSchedulingNumFreeBlocks. + # IF the C++ binding `try_scheduling_request` exists (custom added), use it. + # IF NOT, we must rely on `get_needed_blocks_one_step` and internal tracking. + + # Assuming for Phase 1 you WILL add `try_scheduling_request` to binding + # (It's the safest way to ensure parity). + # If you cannot add binding, let me know, I can write the manual accounting version. + cached_active_list = list(active_requests) idx = 0 @@ -519,45 +518,48 @@ def _schedule_max_utilization(self, active_requests: RequestList): req = cached_active_list[idx] # 1. State Filter - # Skip requests that haven't reached INIT or are already DONE - if (req.state.value < self.no_schedule_until_state.value + # Allow Disagg Gen Init to pass through (matching C++ logic) + is_disagg_init = ( + req.state == LlmRequestState.DISAGG_GENERATION_INIT) + + if not is_disagg_init and ( + req.state.value < self.no_schedule_until_state.value or req.state.value >= self.no_schedule_after_state.value): idx += 1 continue # 2. Max Requests Check if len(scheduled_requests) >= self.max_num_requests: - # If we hit the request count limit, we can't schedule more. - # However, in C++ implementation, we might break or continue logic. - # Usually we break because sorting ensures priority. 
- # But here we simply treat unscheduled active requests as paused implicitly later. break - # 3. Try Allocation (Atomic Check & Update) - # REQUIRED BINDING: `try_scheduling_request(req, max_requests_limit)` - # This binding should map to `trySchedulingRequestMaxUtilization` in C++. - # It performs the check: (available_blocks >= needed_blocks) - # If True, it commits the usage to the transaction and returns True. + # 3. Try Allocation + # Using the binding name you likely need to add: `try_schedule` + # or `check_and_update_allocation`. + # Since it's missing in your provided nanobind, I assume you will add it. can_allocate = False - try: - # Assuming binding takes (req, current_scheduled_count) or just (req) - # if manager doesn't track count. + if hasattr(self.kv_cache_manager, "try_scheduling_request"): can_allocate = self.kv_cache_manager.try_scheduling_request(req) - except AttributeError: - # Fallback for development/mocking - can_allocate = True + else: + # Fallback: Manual check (simplified parity) + # Note: This requires get_needed_blocks_one_step binding + needed = self.kv_cache_manager.get_needed_blocks_one_step( + req, False, 0) # window_size arg? + free_blocks = self.kv_cache_manager.get_num_free_blocks( + ) # This gets Physical free, not Transactional... + # RISK: Without C++ support for transactional state (mSchedulingNumFreeBlocks), + # Python side cannot accurately backtrack. + # STRONG RECOMMENDATION: Add `try_scheduling_request` to C++ binding. + if free_blocks >= needed: + # We can't easily "commit" this without C++ support + can_allocate = True if can_allocate: scheduled_requests.append(req) idx += 1 continue - # 4. Backtracking / Eviction Logic - # If we are here, 'req' did NOT fit. - # Can we pause a previously scheduled RUNNING request to make room? - + # 4. Backtracking victim_idx = -1 - # Search backwards for a Generation request (we don't pause Context init usually) for i in range(len(scheduled_requests) - 1, -1, -1): r = scheduled_requests[i] if r.state == LlmRequestState.GENERATION_IN_PROGRESS: @@ -565,22 +567,15 @@ def _schedule_max_utilization(self, active_requests: RequestList): break if victim_idx != -1: - # Found a victim. Evict it. victim_req = scheduled_requests.pop(victim_idx) paused_requests.append(victim_req) - # REQUIRED BINDING: `scheduling_remove_sequence(req_id)` - # Reverts the block usage of the victim in the current transaction. if hasattr(self.kv_cache_manager, "scheduling_remove_sequence"): self.kv_cache_manager.scheduling_remove_sequence( victim_req.request_id) - # CRITICAL: Do NOT increment `idx`. - # We loop back and try to schedule `req` again, now that space is freed. continue else: - # No valid victim found, and current request doesn't fit. - # We cannot make progress. Stop scheduling. break # 5. Output Classification @@ -611,55 +606,46 @@ def _schedule_max_utilization(self, active_requests: RequestList): return fitting_requests, fitting_disagg_gen_init, paused_requests def _schedule_guaranteed_no_evict(self, active_requests: RequestList): - """ - Conservative strategy. - 1. First, ensure ALL currently running requests have enough memory to run to COMPLETION. - If not, we technically shouldn't schedule them (or system is over-subscribed). - 2. Only then, use remaining memory for New (Context) requests. 
- """ scheduled_requests: RequestList = [] pending_requests: RequestList = [] + # FIX: available = Total - Used max_blocks = self.kv_cache_manager.max_num_blocks - reserved_blocks = 0 + used_blocks = self.kv_cache_manager.get_used_num_blocks( + ) # Binding `used_num_blocks` + available_blocks = max_blocks - used_blocks - # --- Pass 1: Running Requests (Priority) --- + # --- Pass 1: Running Requests --- for request in active_requests: req_state = request.state - # Filter valid states + # State Filter (Disagg Init is NOT a running request, handled in Pass 2) if (req_state.value < self.no_schedule_until_state.value or req_state.value >= self.no_schedule_after_state.value): continue - # Hard constraints check if len(scheduled_requests) >= self.max_num_requests: pending_requests.append(request) continue - # Check Generation Requests if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): + # This returns needed ADDITIONAL blocks needed = self.kv_cache_manager.get_remaining_blocks_to_completion( request) - if reserved_blocks + needed > max_blocks: - # System Overload. Even though policy is NoEvict, we physically can't fit. - # We skip this request (effectively pausing it). - # NOTE: In strict NoEvict, this implies a system error or aggressive over-subscription. + if needed > available_blocks: + # Resource Exhaustion pending_requests.append(request) continue scheduled_requests.append(request) - reserved_blocks += needed + available_blocks -= needed # Reserve space else: - # Defer Context/Init requests to Pass 2 pending_requests.append(request) - # --- Pass 2: New / Context Requests --- - available_blocks = max_blocks - reserved_blocks - + # --- Pass 2: New Requests --- for request in pending_requests: if len(scheduled_requests) >= self.max_num_requests: break @@ -667,15 +653,13 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): if (request.state == LlmRequestState.CONTEXT_INIT or request.state == LlmRequestState.DISAGG_GENERATION_INIT): - needed_blocks = self.kv_cache_manager.get_remaining_blocks_to_completion( + needed = self.kv_cache_manager.get_remaining_blocks_to_completion( request) - if needed_blocks <= available_blocks: + if needed <= available_blocks: scheduled_requests.append(request) - available_blocks -= needed_blocks + available_blocks -= needed else: - # Cannot fit new request. Since we cannot evict running requests, - # we stop here (Head-of-Line blocking behavior is typical for FIFO). break # --- Output Classification --- From 84cebc9a747aff59b1d38e7176c7eb259afeb24f Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:06:51 +0800 Subject: [PATCH 06/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 121 +++++++++----------- 1 file changed, 51 insertions(+), 70 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 1ec4af41376..d86bf7707ac 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -457,10 +457,6 @@ def _fit_draft_tokens(self, requests, capacity, unit_size): class PyCapacityScheduler: """ Python implementation of the C++ CapacityScheduler. - - It delegates the heavy lifting of block counting and tracking to the C++ - KVCacheManager via bindings, but controls the decision-making loop (policy) - in Python. 
""" def __init__( @@ -478,13 +474,12 @@ def __init__( self.no_schedule_until_state = no_schedule_until_state self.no_schedule_after_state = no_schedule_after_state + # [FIX]: Get this from config! + self.default_window_size = self.kv_cache_manager.max_sequence_length + def schedule_request( self, active_requests: RequestList ) -> Tuple[RequestList, RequestList, RequestList]: - """ - Main entry point. - Returns: (fitting_requests, fitting_disagg_gen_init, paused_requests) - """ if self.policy == CapacitySchedulerPolicy.MAX_UTILIZATION: return self._schedule_max_utilization(active_requests) @@ -498,18 +493,10 @@ def _schedule_max_utilization(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] - # Binding: Resets internal transactional state (e.g. mSchedulingNumFreeBlocks) - if hasattr(self.kv_cache_manager, "start_scheduling"): - self.kv_cache_manager.start_scheduling() - - # We track "scheduling" free blocks in Python to mimic C++ logic if binding missing - # C++ Manager tracks this internally via mSchedulingNumFreeBlocks. - # IF the C++ binding `try_scheduling_request` exists (custom added), use it. - # IF NOT, we must rely on `get_needed_blocks_one_step` and internal tracking. - - # Assuming for Phase 1 you WILL add `try_scheduling_request` to binding - # (It's the safest way to ensure parity). - # If you cannot add binding, let me know, I can write the manual accounting version. + # [FIX] Track free blocks manually in Python because C++ internal transactional state + # is not exposed/updated via bindings for tentative scheduling. + # get_num_free_blocks() returns the physical free blocks. + current_free_blocks = self.kv_cache_manager.get_num_free_blocks() cached_active_list = list(active_requests) idx = 0 @@ -532,33 +519,29 @@ def _schedule_max_utilization(self, active_requests: RequestList): if len(scheduled_requests) >= self.max_num_requests: break - # 3. Try Allocation - # Using the binding name you likely need to add: `try_schedule` - # or `check_and_update_allocation`. - # Since it's missing in your provided nanobind, I assume you will add it. - can_allocate = False - if hasattr(self.kv_cache_manager, "try_scheduling_request"): - can_allocate = self.kv_cache_manager.try_scheduling_request(req) - else: - # Fallback: Manual check (simplified parity) - # Note: This requires get_needed_blocks_one_step binding - needed = self.kv_cache_manager.get_needed_blocks_one_step( - req, False, 0) # window_size arg? - free_blocks = self.kv_cache_manager.get_num_free_blocks( - ) # This gets Physical free, not Transactional... - # RISK: Without C++ support for transactional state (mSchedulingNumFreeBlocks), - # Python side cannot accurately backtrack. - # STRONG RECOMMENDATION: Add `try_scheduling_request` to C++ binding. - if free_blocks >= needed: - # We can't easily "commit" this without C++ support - can_allocate = True - - if can_allocate: + # 3. 
Try Allocation (Python Manual Check) + # [FIX] Use get_needed_blocks_one_step instead of missing try_scheduling_request + needed_blocks = 0 + if is_disagg_init: + # Disagg Init needs special calculation usually same as Context + needed_blocks = self.kv_cache_manager.get_needed_blocks_one_step( + req, False, self.default_window_size) + elif req.state == LlmRequestState.GENERATION_IN_PROGRESS: + # Generation usually needs 0 or 1 block depending on boundary + needed_blocks = self.kv_cache_manager.get_needed_blocks_one_step( + req, False, self.default_window_size) + elif req.state == LlmRequestState.CONTEXT_INIT: + needed_blocks = self.kv_cache_manager.get_needed_blocks_one_step( + req, False, self.default_window_size) + + if current_free_blocks >= needed_blocks: + # Commit locally + current_free_blocks -= needed_blocks scheduled_requests.append(req) idx += 1 continue - # 4. Backtracking + # 4. Backtracking / Eviction Logic victim_idx = -1 for i in range(len(scheduled_requests) - 1, -1, -1): r = scheduled_requests[i] @@ -567,22 +550,25 @@ def _schedule_max_utilization(self, active_requests: RequestList): break if victim_idx != -1: + # Found a victim. Evict it. victim_req = scheduled_requests.pop(victim_idx) paused_requests.append(victim_req) - if hasattr(self.kv_cache_manager, "scheduling_remove_sequence"): - self.kv_cache_manager.scheduling_remove_sequence( - victim_req.request_id) + # [FIX] Reclaim victim's blocks manually + victim_needed = self.kv_cache_manager.get_needed_blocks_one_step( + victim_req, False, self.default_window_size) + current_free_blocks += victim_needed + # Retry current req without incrementing idx continue else: + # No victim found, and current request doesn't fit. break - # 5. Output Classification - # Any active request not in `scheduled_requests` is effectively paused/waiting. - # But `paused_requests` list contains specifically those we *actively* evicted. + # 5. Output Classification (Same as before) + fitting_requests = [] + fitting_disagg_gen_init = [] - # We also need to capture requests that were active but we stopped loop before reaching them. scheduled_ids = set(r.request_id for r in scheduled_requests) evicted_ids = set(r.request_id for r in paused_requests) @@ -590,13 +576,8 @@ def _schedule_max_utilization(self, active_requests: RequestList): if (req.state == LlmRequestState.GENERATION_IN_PROGRESS and req.request_id not in scheduled_ids and req.request_id not in evicted_ids): - # Request was running, but we ran out of slots/memory before processing it - # or we stopped scheduling. 
paused_requests.append(req) - fitting_requests = [] - fitting_disagg_gen_init = [] - for r in scheduled_requests: if r.state == LlmRequestState.DISAGG_GENERATION_INIT: fitting_disagg_gen_init.append(r) @@ -609,17 +590,18 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests: RequestList = [] pending_requests: RequestList = [] - # FIX: available = Total - Used max_blocks = self.kv_cache_manager.max_num_blocks - used_blocks = self.kv_cache_manager.get_used_num_blocks( - ) # Binding `used_num_blocks` + + # [FIX] Must subtract used blocks to get available for *new* allocations + used_blocks = self.kv_cache_manager.get_used_num_blocks() + + # We track 'reserved' as blocks we PLAN to add on top of 'used' available_blocks = max_blocks - used_blocks # --- Pass 1: Running Requests --- for request in active_requests: req_state = request.state - # State Filter (Disagg Init is NOT a running request, handled in Pass 2) if (req_state.value < self.no_schedule_until_state.value or req_state.value >= self.no_schedule_after_state.value): continue @@ -631,21 +613,20 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): - # This returns needed ADDITIONAL blocks + # [FIX] Added window_size argument needed = self.kv_cache_manager.get_remaining_blocks_to_completion( - request) + request, self.default_window_size) if needed > available_blocks: - # Resource Exhaustion pending_requests.append(request) continue scheduled_requests.append(request) - available_blocks -= needed # Reserve space + available_blocks -= needed else: pending_requests.append(request) - # --- Pass 2: New Requests --- + # --- Pass 2: New / Context Requests --- for request in pending_requests: if len(scheduled_requests) >= self.max_num_requests: break @@ -653,23 +634,23 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): if (request.state == LlmRequestState.CONTEXT_INIT or request.state == LlmRequestState.DISAGG_GENERATION_INIT): - needed = self.kv_cache_manager.get_remaining_blocks_to_completion( - request) + # [FIX] Added window_size argument + needed_blocks = self.kv_cache_manager.get_remaining_blocks_to_completion( + request, self.default_window_size) - if needed <= available_blocks: + if needed_blocks <= available_blocks: scheduled_requests.append(request) - available_blocks -= needed + available_blocks -= needed_blocks else: break - # --- Output Classification --- + # --- Output Classification (Same as before) --- fitting_requests = [] fitting_disagg_gen_init = [] paused_requests = [] scheduled_ids = set(r.request_id for r in scheduled_requests) - # Identify running requests that were implicitly paused for req in active_requests: if (req.request_id not in scheduled_ids and req.state == LlmRequestState.GENERATION_IN_PROGRESS): From 490f8e977689a1e06aca66d10dafcaab370a4d7e Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:18:29 +0800 Subject: [PATCH 07/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 3 +-- tensorrt_llm/_torch/pyexecutor/scheduler.py | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 9d1b7aa3e36..7477953638d 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ 
b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -816,8 +816,7 @@ def create_py_executor_instance( scheduler = SimpleUnifiedScheduler( max_batch_size=max_batch_size, max_num_tokens=max_num_tokens, - kv_cache_manager=kv_cache_manager.impl - if kv_cache_manager is not None else None, + kv_cache_manager=kv_cache_manager, scheduler_policy=scheduler_config.capacity_scheduler_policy, ctx_chunk_config=ctx_chunk_config) else: diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index d86bf7707ac..e0a80155690 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -470,12 +470,13 @@ def __init__( ): self.max_num_requests = max_num_requests self.kv_cache_manager = kv_cache_manager + self.kv_cache_manager_cpp = kv_cache_manager.impl self.policy = scheduler_policy self.no_schedule_until_state = no_schedule_until_state self.no_schedule_after_state = no_schedule_after_state # [FIX]: Get this from config! - self.default_window_size = self.kv_cache_manager.max_sequence_length + self.default_window_size = self.kv_cache_manager.max_seq_len def schedule_request( self, active_requests: RequestList @@ -496,7 +497,7 @@ def _schedule_max_utilization(self, active_requests: RequestList): # [FIX] Track free blocks manually in Python because C++ internal transactional state # is not exposed/updated via bindings for tentative scheduling. # get_num_free_blocks() returns the physical free blocks. - current_free_blocks = self.kv_cache_manager.get_num_free_blocks() + current_free_blocks = self.kv_cache_manager_cpp.get_num_free_blocks() cached_active_list = list(active_requests) idx = 0 @@ -524,14 +525,14 @@ def _schedule_max_utilization(self, active_requests: RequestList): needed_blocks = 0 if is_disagg_init: # Disagg Init needs special calculation usually same as Context - needed_blocks = self.kv_cache_manager.get_needed_blocks_one_step( + needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( req, False, self.default_window_size) elif req.state == LlmRequestState.GENERATION_IN_PROGRESS: # Generation usually needs 0 or 1 block depending on boundary - needed_blocks = self.kv_cache_manager.get_needed_blocks_one_step( + needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( req, False, self.default_window_size) elif req.state == LlmRequestState.CONTEXT_INIT: - needed_blocks = self.kv_cache_manager.get_needed_blocks_one_step( + needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( req, False, self.default_window_size) if current_free_blocks >= needed_blocks: @@ -555,7 +556,7 @@ def _schedule_max_utilization(self, active_requests: RequestList): paused_requests.append(victim_req) # [FIX] Reclaim victim's blocks manually - victim_needed = self.kv_cache_manager.get_needed_blocks_one_step( + victim_needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( victim_req, False, self.default_window_size) current_free_blocks += victim_needed @@ -590,10 +591,10 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests: RequestList = [] pending_requests: RequestList = [] - max_blocks = self.kv_cache_manager.max_num_blocks + max_blocks = self.kv_cache_manager_cpp.max_num_blocks # [FIX] Must subtract used blocks to get available for *new* allocations - used_blocks = self.kv_cache_manager.get_used_num_blocks() + used_blocks = self.kv_cache_manager_cpp.get_used_num_blocks() # We track 'reserved' as blocks we PLAN to add on top of 'used' available_blocks = 
max_blocks - used_blocks @@ -614,7 +615,7 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): or req_state == LlmRequestState.GENERATION_TO_COMPLETE): # [FIX] Added window_size argument - needed = self.kv_cache_manager.get_remaining_blocks_to_completion( + needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( request, self.default_window_size) if needed > available_blocks: @@ -635,7 +636,7 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): or request.state == LlmRequestState.DISAGG_GENERATION_INIT): # [FIX] Added window_size argument - needed_blocks = self.kv_cache_manager.get_remaining_blocks_to_completion( + needed_blocks = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( request, self.default_window_size) if needed_blocks <= available_blocks: From 4e62403922025affa85fea346b608d56b9537cd8 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:26:00 +0800 Subject: [PATCH 08/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 34 +++++++++++++-------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index e0a80155690..2e08450f135 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -11,6 +11,7 @@ from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy from .llm_request import LlmRequest, LlmRequestState +from .resource_manager import KVCacheManager RequestList = list[LlmRequest] @@ -462,7 +463,7 @@ class PyCapacityScheduler: def __init__( self, max_num_requests: int, - kv_cache_manager, + kv_cache_manager: KVCacheManager, scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. MAX_UTILIZATION, no_schedule_until_state=LlmRequestState.CONTEXT_INIT, @@ -470,12 +471,13 @@ def __init__( ): self.max_num_requests = max_num_requests self.kv_cache_manager = kv_cache_manager + # Use the underlying C++ object via bindings self.kv_cache_manager_cpp = kv_cache_manager.impl self.policy = scheduler_policy self.no_schedule_until_state = no_schedule_until_state self.no_schedule_after_state = no_schedule_after_state - # [FIX]: Get this from config! + # [FIX]: Get this from config/wrapper to ensure C++ API compatibility self.default_window_size = self.kv_cache_manager.max_seq_len def schedule_request( @@ -494,9 +496,9 @@ def _schedule_max_utilization(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] - # [FIX] Track free blocks manually in Python because C++ internal transactional state - # is not exposed/updated via bindings for tentative scheduling. - # get_num_free_blocks() returns the physical free blocks. + # Track free blocks manually in Python to simulate transactional state. + # get_num_free_blocks() returns the current physical free blocks. + # We subtract from this as we tentatively schedule requests. current_free_blocks = self.kv_cache_manager_cpp.get_num_free_blocks() cached_active_list = list(active_requests) @@ -521,7 +523,7 @@ def _schedule_max_utilization(self, active_requests: RequestList): break # 3. 
Try Allocation (Python Manual Check) - # [FIX] Use get_needed_blocks_one_step instead of missing try_scheduling_request + # Use get_needed_blocks_one_step to calculate incremental need needed_blocks = 0 if is_disagg_init: # Disagg Init needs special calculation usually same as Context @@ -536,13 +538,15 @@ def _schedule_max_utilization(self, active_requests: RequestList): req, False, self.default_window_size) if current_free_blocks >= needed_blocks: - # Commit locally + # Commit locally (transactional update) current_free_blocks -= needed_blocks scheduled_requests.append(req) idx += 1 continue # 4. Backtracking / Eviction Logic + # If current request doesn't fit, try to evict a previously scheduled + # GENERATION request to make room. victim_idx = -1 for i in range(len(scheduled_requests) - 1, -1, -1): r = scheduled_requests[i] @@ -555,18 +559,19 @@ def _schedule_max_utilization(self, active_requests: RequestList): victim_req = scheduled_requests.pop(victim_idx) paused_requests.append(victim_req) - # [FIX] Reclaim victim's blocks manually + # Reclaim victim's blocks manually + # We simply give back what we subtracted earlier for this victim. victim_needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( victim_req, False, self.default_window_size) current_free_blocks += victim_needed - # Retry current req without incrementing idx + # Retry current req (do NOT increment idx) continue else: - # No victim found, and current request doesn't fit. + # No victim found, and current request doesn't fit. Stop. break - # 5. Output Classification (Same as before) + # 5. Output Classification fitting_requests = [] fitting_disagg_gen_init = [] @@ -577,6 +582,7 @@ def _schedule_max_utilization(self, active_requests: RequestList): if (req.state == LlmRequestState.GENERATION_IN_PROGRESS and req.request_id not in scheduled_ids and req.request_id not in evicted_ids): + # Request was running but dropped due to capacity/limit paused_requests.append(req) for r in scheduled_requests: @@ -619,12 +625,14 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): request, self.default_window_size) if needed > available_blocks: + # System Overload (should ideally not happen in NoEvict) pending_requests.append(request) continue scheduled_requests.append(request) available_blocks -= needed else: + # Defer Context/Init requests to Pass 2 pending_requests.append(request) # --- Pass 2: New / Context Requests --- @@ -643,15 +651,17 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests.append(request) available_blocks -= needed_blocks else: + # Cannot fit new request. Stop. 
break - # --- Output Classification (Same as before) --- + # --- Output Classification --- fitting_requests = [] fitting_disagg_gen_init = [] paused_requests = [] scheduled_ids = set(r.request_id for r in scheduled_requests) + # Identify running requests that were implicitly paused for req in active_requests: if (req.request_id not in scheduled_ids and req.state == LlmRequestState.GENERATION_IN_PROGRESS): From 87caccbd839cea53ce8a2136d88cc1c859c31f20 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:35:50 +0800 Subject: [PATCH 09/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 2e08450f135..abbebbe1a55 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -11,7 +11,6 @@ from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy from .llm_request import LlmRequest, LlmRequestState -from .resource_manager import KVCacheManager RequestList = list[LlmRequest] @@ -463,7 +462,7 @@ class PyCapacityScheduler: def __init__( self, max_num_requests: int, - kv_cache_manager: KVCacheManager, + kv_cache_manager, scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. MAX_UTILIZATION, no_schedule_until_state=LlmRequestState.CONTEXT_INIT, From d1aebe7ce357b40ecba40ecc6a9013cb6abb3630 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:43:16 +0800 Subject: [PATCH 10/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 38 ++++++++------------- 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index abbebbe1a55..b94624808a3 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -495,10 +495,14 @@ def _schedule_max_utilization(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] + # [FIX] Use get_kv_cache_stats() to get block counts safely + if hasattr(self.kv_cache_manager, "start_scheduling"): + self.kv_cache_manager.start_scheduling() + + # Get snapshot of current state + stats = self.kv_cache_manager_cpp.get_kv_cache_stats() # Track free blocks manually in Python to simulate transactional state. - # get_num_free_blocks() returns the current physical free blocks. - # We subtract from this as we tentatively schedule requests. - current_free_blocks = self.kv_cache_manager_cpp.get_num_free_blocks() + current_free_blocks = stats.free_num_blocks cached_active_list = list(active_requests) idx = 0 @@ -507,7 +511,6 @@ def _schedule_max_utilization(self, active_requests: RequestList): req = cached_active_list[idx] # 1. State Filter - # Allow Disagg Gen Init to pass through (matching C++ logic) is_disagg_init = ( req.state == LlmRequestState.DISAGG_GENERATION_INIT) @@ -522,14 +525,11 @@ def _schedule_max_utilization(self, active_requests: RequestList): break # 3. 
Try Allocation (Python Manual Check) - # Use get_needed_blocks_one_step to calculate incremental need needed_blocks = 0 if is_disagg_init: - # Disagg Init needs special calculation usually same as Context needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( req, False, self.default_window_size) elif req.state == LlmRequestState.GENERATION_IN_PROGRESS: - # Generation usually needs 0 or 1 block depending on boundary needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( req, False, self.default_window_size) elif req.state == LlmRequestState.CONTEXT_INIT: @@ -537,15 +537,13 @@ def _schedule_max_utilization(self, active_requests: RequestList): req, False, self.default_window_size) if current_free_blocks >= needed_blocks: - # Commit locally (transactional update) + # Commit locally current_free_blocks -= needed_blocks scheduled_requests.append(req) idx += 1 continue # 4. Backtracking / Eviction Logic - # If current request doesn't fit, try to evict a previously scheduled - # GENERATION request to make room. victim_idx = -1 for i in range(len(scheduled_requests) - 1, -1, -1): r = scheduled_requests[i] @@ -559,15 +557,14 @@ def _schedule_max_utilization(self, active_requests: RequestList): paused_requests.append(victim_req) # Reclaim victim's blocks manually - # We simply give back what we subtracted earlier for this victim. victim_needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( victim_req, False, self.default_window_size) current_free_blocks += victim_needed - # Retry current req (do NOT increment idx) + # Retry current req without incrementing idx continue else: - # No victim found, and current request doesn't fit. Stop. + # No victim found, and current request doesn't fit. break # 5. Output Classification @@ -581,7 +578,6 @@ def _schedule_max_utilization(self, active_requests: RequestList): if (req.state == LlmRequestState.GENERATION_IN_PROGRESS and req.request_id not in scheduled_ids and req.request_id not in evicted_ids): - # Request was running but dropped due to capacity/limit paused_requests.append(req) for r in scheduled_requests: @@ -596,10 +592,10 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests: RequestList = [] pending_requests: RequestList = [] - max_blocks = self.kv_cache_manager_cpp.max_num_blocks - - # [FIX] Must subtract used blocks to get available for *new* allocations - used_blocks = self.kv_cache_manager_cpp.get_used_num_blocks() + # [FIX] Use get_kv_cache_stats() to fetch state atomically and robustly + stats = self.kv_cache_manager_cpp.get_kv_cache_stats() + max_blocks = stats.max_num_blocks + used_blocks = stats.used_num_blocks # We track 'reserved' as blocks we PLAN to add on top of 'used' available_blocks = max_blocks - used_blocks @@ -619,19 +615,16 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): - # [FIX] Added window_size argument needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( request, self.default_window_size) if needed > available_blocks: - # System Overload (should ideally not happen in NoEvict) pending_requests.append(request) continue scheduled_requests.append(request) available_blocks -= needed else: - # Defer Context/Init requests to Pass 2 pending_requests.append(request) # --- Pass 2: New / Context Requests --- @@ -642,7 +635,6 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): if 
(request.state == LlmRequestState.CONTEXT_INIT or request.state == LlmRequestState.DISAGG_GENERATION_INIT): - # [FIX] Added window_size argument needed_blocks = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( request, self.default_window_size) @@ -650,7 +642,6 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests.append(request) available_blocks -= needed_blocks else: - # Cannot fit new request. Stop. break # --- Output Classification --- @@ -660,7 +651,6 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_ids = set(r.request_id for r in scheduled_requests) - # Identify running requests that were implicitly paused for req in active_requests: if (req.request_id not in scheduled_ids and req.state == LlmRequestState.GENERATION_IN_PROGRESS): From 4d1f530580bae020b51b0034e6b11e2dfc192f17 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:46:01 +0800 Subject: [PATCH 11/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index b94624808a3..4aa9329a711 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -305,7 +305,7 @@ def schedule( elif req.state == LlmRequestState.CONTEXT_INIT: if not self.ctx_chunk_config: # No Chunking: Greedy allocation - req_num_tokens = req.get_context_remaining_length() + req_num_tokens = req.context_remaining_length draft_tokens = req.num_draft_tokens if req.has_draft_tokens else 0 total_tokens = req_num_tokens + draft_tokens @@ -318,7 +318,7 @@ def schedule( current_batch_tokens += total_tokens else: # Chunking Enabled: Defer calculation - remaining = req.get_context_remaining_length() + remaining = req.context_remaining_length # Just an estimate for budget check req.context_chunk_size = remaining @@ -385,7 +385,7 @@ def _chunk_equal_progress(self, requests, capacity, unit_size): made_progress = False for req in requests: past_size = req.context_chunk_size - remaining = req.get_context_remaining_length() + remaining = req.context_remaining_length if past_size >= remaining: continue @@ -409,7 +409,7 @@ def _chunk_fcfs(self, requests, capacity, unit_size): current_capacity = capacity if capacity is not None else float('inf') for req in requests: - remaining = req.get_context_remaining_length() + remaining = req.context_remaining_length actual_size = remaining if current_capacity < actual_size: From 641236dd1861e848040f6f915a35cc354f461de2 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:59:46 +0800 Subject: [PATCH 12/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 4aa9329a711..61fb095fa65 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -684,9 +684,16 @@ def __init__( # 2. 
Initialize Python MicroBatch Scheduler py_chunk_config = None if ctx_chunk_config: - # Convert StrEnum to our Python Enum - policy_enum = ChunkingPolicy.EQUAL_PROGRESS if ctx_chunk_config[ - 0] == tb_internal.batch_manager.ChunkingPolicy.EQUAL_PROGRESS else ChunkingPolicy.FIRST_COME_FIRST_SERVED + # Fix: Use string comparison to identify the policy. + # This works regardless of whether the input is a Python Enum, C++ Binding Enum, or String. + input_policy = ctx_chunk_config[0] + + if "EQUAL_PROGRESS" in str(input_policy): + policy_enum = ChunkingPolicy.EQUAL_PROGRESS + else: + # Default to FCFS for FIRST_COME_FIRST_SERVED or others + policy_enum = ChunkingPolicy.FIRST_COME_FIRST_SERVED + py_chunk_config = ContextChunkingConfig(policy_enum, ctx_chunk_config[1]) From 162d59ea9ab4cb3d29aa70c49bc7631af6526ab5 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 17:46:35 +0800 Subject: [PATCH 13/25] enable py scheduler Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index cea56431b77..2f4ff469e55 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -17,6 +17,7 @@ # Disable UCC to WAR allgather issue before NGC PyTorch 25.12 upgrade. os.environ["OMPI_MCA_coll_ucc_enable"] = "0" +os.environ["TLLM_USE_PYTHON_SCHEDULER"] = "1" def _add_trt_llm_dll_directory(): From 707fb4a87b9953bf8cec189d6fe07ec3f1be9bb7 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:23:42 +0800 Subject: [PATCH 14/25] support bert Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 48 +++++++++++++++++++-- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 61fb095fa65..df186be3d31 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -470,19 +470,22 @@ def __init__( ): self.max_num_requests = max_num_requests self.kv_cache_manager = kv_cache_manager - # Use the underlying C++ object via bindings - self.kv_cache_manager_cpp = kv_cache_manager.impl self.policy = scheduler_policy self.no_schedule_until_state = no_schedule_until_state self.no_schedule_after_state = no_schedule_after_state - # [FIX]: Get this from config/wrapper to ensure C++ API compatibility - self.default_window_size = self.kv_cache_manager.max_seq_len + if self.kv_cache_manager is not None: + self.kv_cache_manager_cpp = kv_cache_manager.impl + self.default_window_size = self.kv_cache_manager.max_seq_len def schedule_request( self, active_requests: RequestList ) -> Tuple[RequestList, RequestList, RequestList]: + # 1. Handle No KV Cache Manager -> MaxRequestsScheduler Logic + if self.kv_cache_manager is None: + return self._schedule_max_requests(active_requests) + # 2. Handle Policies with KV Cache Manager if self.policy == CapacitySchedulerPolicy.MAX_UTILIZATION: return self._schedule_max_utilization(active_requests) elif self.policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: @@ -491,6 +494,43 @@ def schedule_request( raise NotImplementedError( f"Policy {self.policy} not implemented in PyCapacityScheduler") + def _schedule_max_requests(self, active_requests: RequestList): + """ + Simple scheduler that only limits the maximum number of requests. 
+ Used when no KV Cache Manager is available. + """ + scheduled_requests: RequestList = [] + paused_requests: RequestList = [] + + for req in active_requests: + # 1. State Filter + if (req.state.value < self.no_schedule_until_state.value + or req.state.value >= self.no_schedule_after_state.value): + continue + + # 2. Max Requests Check + if len(scheduled_requests) >= self.max_num_requests: + break + + # 3. Schedule valid states + # Note: LlmRequest properties might vary, using state enum checks for safety + if (req.state == LlmRequestState.ENCODER_INIT + or req.state == LlmRequestState.CONTEXT_INIT + or req.state == LlmRequestState.GENERATION_IN_PROGRESS): + scheduled_requests.append(req) + + # Output Classification (Standard for all schedulers) + fitting_requests = [] + fitting_disagg_gen_init = [] + + for r in scheduled_requests: + if r.state == LlmRequestState.DISAGG_GENERATION_INIT: + fitting_disagg_gen_init.append(r) + else: + fitting_requests.append(r) + + return fitting_requests, fitting_disagg_gen_init, paused_requests + def _schedule_max_utilization(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] From fbc8486f8ad6dd0ffe07c78cbfd0a00a2916d9fd Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Thu, 18 Dec 2025 09:50:36 +0800 Subject: [PATCH 15/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 31 ++++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index df186be3d31..75a60eeef34 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -630,26 +630,33 @@ def _schedule_max_utilization(self, active_requests: RequestList): def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests: RequestList = [] + # Separate pending lists to match C++ priority logic pending_requests: RequestList = [] + pending_disagg_requests: RequestList = [] - # [FIX] Use get_kv_cache_stats() to fetch state atomically and robustly stats = self.kv_cache_manager_cpp.get_kv_cache_stats() max_blocks = stats.max_num_blocks used_blocks = stats.used_num_blocks - - # We track 'reserved' as blocks we PLAN to add on top of 'used' available_blocks = max_blocks - used_blocks # --- Pass 1: Running Requests --- for request in active_requests: req_state = request.state - if (req_state.value < self.no_schedule_until_state.value + is_disagg_init = ( + req_state == LlmRequestState.DISAGG_GENERATION_INIT) + + if not is_disagg_init and ( + req_state.value < self.no_schedule_until_state.value or req_state.value >= self.no_schedule_after_state.value): continue if len(scheduled_requests) >= self.max_num_requests: - pending_requests.append(request) + # Still check state to sort into correct pending list + if is_disagg_init: + pending_disagg_requests.append(request) + else: + pending_requests.append(request) continue if (req_state == LlmRequestState.GENERATION_IN_PROGRESS @@ -665,10 +672,19 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests.append(request) available_blocks -= needed else: - pending_requests.append(request) + # Add to pending lists based on type + if is_disagg_init: + pending_disagg_requests.append(request) + else: + pending_requests.append(request) # --- Pass 2: New / Context Requests --- - for request in 
pending_requests: + # C++ logic prioritizes Disagg Generation Init requests over standard Context Init + # So we iterate pending_disagg_requests FIRST, then pending_requests + + all_pending = pending_disagg_requests + pending_requests + + for request in all_pending: if len(scheduled_requests) >= self.max_num_requests: break @@ -682,6 +698,7 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests.append(request) available_blocks -= needed_blocks else: + # Head-of-line blocking logic (standard in NoEvict) break # --- Output Classification --- From d3446704cb0dea8d65dcddd29eeceeecdb44ef65 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Thu, 18 Dec 2025 09:58:10 +0800 Subject: [PATCH 16/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 113 ++++++++------------ 1 file changed, 46 insertions(+), 67 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 75a60eeef34..c01a41d0177 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -481,6 +481,7 @@ def __init__( def schedule_request( self, active_requests: RequestList ) -> Tuple[RequestList, RequestList, RequestList]: + # 1. Handle No KV Cache Manager -> MaxRequestsScheduler Logic if self.kv_cache_manager is None: return self._schedule_max_requests(active_requests) @@ -495,16 +496,17 @@ def schedule_request( f"Policy {self.policy} not implemented in PyCapacityScheduler") def _schedule_max_requests(self, active_requests: RequestList): - """ - Simple scheduler that only limits the maximum number of requests. - Used when no KV Cache Manager is available. - """ scheduled_requests: RequestList = [] paused_requests: RequestList = [] for req in active_requests: # 1. State Filter - if (req.state.value < self.no_schedule_until_state.value + # Allow Disagg Gen Init to pass through + is_disagg_init = ( + req.state == LlmRequestState.DISAGG_GENERATION_INIT) + + if not is_disagg_init and ( + req.state.value < self.no_schedule_until_state.value or req.state.value >= self.no_schedule_after_state.value): continue @@ -513,35 +515,24 @@ def _schedule_max_requests(self, active_requests: RequestList): break # 3. 
Schedule valid states - # Note: LlmRequest properties might vary, using state enum checks for safety if (req.state == LlmRequestState.ENCODER_INIT or req.state == LlmRequestState.CONTEXT_INIT - or req.state == LlmRequestState.GENERATION_IN_PROGRESS): + or req.state == LlmRequestState.GENERATION_IN_PROGRESS + or is_disagg_init): scheduled_requests.append(req) - # Output Classification (Standard for all schedulers) - fitting_requests = [] - fitting_disagg_gen_init = [] - - for r in scheduled_requests: - if r.state == LlmRequestState.DISAGG_GENERATION_INIT: - fitting_disagg_gen_init.append(r) - else: - fitting_requests.append(r) - - return fitting_requests, fitting_disagg_gen_init, paused_requests + return self._classify_output(active_requests, scheduled_requests, + paused_requests) def _schedule_max_utilization(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] - # [FIX] Use get_kv_cache_stats() to get block counts safely if hasattr(self.kv_cache_manager, "start_scheduling"): self.kv_cache_manager.start_scheduling() - # Get snapshot of current state - stats = self.kv_cache_manager_cpp.get_kv_cache_stats() # Track free blocks manually in Python to simulate transactional state. + stats = self.kv_cache_manager_cpp.get_kv_cache_stats() current_free_blocks = stats.free_num_blocks cached_active_list = list(active_requests) @@ -551,6 +542,7 @@ def _schedule_max_utilization(self, active_requests: RequestList): req = cached_active_list[idx] # 1. State Filter + # Allow Disagg Gen Init to pass through (matching C++ logic) is_disagg_init = ( req.state == LlmRequestState.DISAGG_GENERATION_INIT) @@ -577,7 +569,6 @@ def _schedule_max_utilization(self, active_requests: RequestList): req, False, self.default_window_size) if current_free_blocks >= needed_blocks: - # Commit locally current_free_blocks -= needed_blocks scheduled_requests.append(req) idx += 1 @@ -604,35 +595,18 @@ def _schedule_max_utilization(self, active_requests: RequestList): # Retry current req without incrementing idx continue else: - # No victim found, and current request doesn't fit. + # No victim found, and current request doesn't fit. Stop. break - # 5. 
Output Classification - fitting_requests = [] - fitting_disagg_gen_init = [] - - scheduled_ids = set(r.request_id for r in scheduled_requests) - evicted_ids = set(r.request_id for r in paused_requests) - - for req in active_requests: - if (req.state == LlmRequestState.GENERATION_IN_PROGRESS - and req.request_id not in scheduled_ids - and req.request_id not in evicted_ids): - paused_requests.append(req) - - for r in scheduled_requests: - if r.state == LlmRequestState.DISAGG_GENERATION_INIT: - fitting_disagg_gen_init.append(r) - else: - fitting_requests.append(r) - - return fitting_requests, fitting_disagg_gen_init, paused_requests + return self._classify_output(active_requests, scheduled_requests, + paused_requests) def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests: RequestList = [] - # Separate pending lists to match C++ priority logic - pending_requests: RequestList = [] + + # Pending lists separated to enforce priority: Disagg Init > Context Init pending_disagg_requests: RequestList = [] + pending_context_requests: RequestList = [] stats = self.kv_cache_manager_cpp.get_kv_cache_stats() max_blocks = stats.max_num_blocks @@ -642,23 +616,24 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): # --- Pass 1: Running Requests --- for request in active_requests: req_state = request.state - is_disagg_init = ( req_state == LlmRequestState.DISAGG_GENERATION_INIT) + # Filter valid states (Allow Disagg Init) if not is_disagg_init and ( req_state.value < self.no_schedule_until_state.value or req_state.value >= self.no_schedule_after_state.value): continue + # Capacity Check if len(scheduled_requests) >= self.max_num_requests: - # Still check state to sort into correct pending list if is_disagg_init: pending_disagg_requests.append(request) else: - pending_requests.append(request) + pending_context_requests.append(request) continue + # Prioritize Running Requests (Generation) if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): @@ -666,50 +641,54 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): request, self.default_window_size) if needed > available_blocks: - pending_requests.append(request) + # If running req doesn't fit, skip it (effectively pause) + pending_context_requests.append(request) continue scheduled_requests.append(request) available_blocks -= needed else: - # Add to pending lists based on type + # Non-running requests go to pending if is_disagg_init: pending_disagg_requests.append(request) else: - pending_requests.append(request) + pending_context_requests.append(request) # --- Pass 2: New / Context Requests --- - # C++ logic prioritizes Disagg Generation Init requests over standard Context Init - # So we iterate pending_disagg_requests FIRST, then pending_requests - - all_pending = pending_disagg_requests + pending_requests + # Critical: Process Disagg Init requests BEFORE standard Context Init + all_pending = pending_disagg_requests + pending_context_requests for request in all_pending: if len(scheduled_requests) >= self.max_num_requests: break - if (request.state == LlmRequestState.CONTEXT_INIT - or request.state == LlmRequestState.DISAGG_GENERATION_INIT): + # Note: For Disagg Init, get_remaining_blocks_to_completion calculates + # full prompt + generation needs, which is what we want. 
+ needed_blocks = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( + request, self.default_window_size) - needed_blocks = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( - request, self.default_window_size) + if needed_blocks <= available_blocks: + scheduled_requests.append(request) + available_blocks -= needed_blocks + else: + # Head-of-line blocking (Standard NoEvict behavior) + break - if needed_blocks <= available_blocks: - scheduled_requests.append(request) - available_blocks -= needed_blocks - else: - # Head-of-line blocking logic (standard in NoEvict) - break + return self._classify_output(active_requests, scheduled_requests, []) - # --- Output Classification --- + def _classify_output(self, active_requests, scheduled_requests, + explicit_paused_requests): fitting_requests = [] fitting_disagg_gen_init = [] - paused_requests = [] + paused_requests = list(explicit_paused_requests) scheduled_ids = set(r.request_id for r in scheduled_requests) + paused_ids = set(r.request_id for r in paused_requests) + # Identify running requests that were implicitly paused (dropped) for req in active_requests: if (req.request_id not in scheduled_ids + and req.request_id not in paused_ids and req.state == LlmRequestState.GENERATION_IN_PROGRESS): paused_requests.append(req) From 6617a47d0136828a2a3cfe67d6f888c9e88b54b4 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Thu, 18 Dec 2025 10:09:59 +0800 Subject: [PATCH 17/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index c01a41d0177..32e2a31483f 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -619,33 +619,28 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): is_disagg_init = ( req_state == LlmRequestState.DISAGG_GENERATION_INIT) - # Filter valid states (Allow Disagg Init) if not is_disagg_init and ( req_state.value < self.no_schedule_until_state.value or req_state.value >= self.no_schedule_after_state.value): continue - # Capacity Check if len(scheduled_requests) >= self.max_num_requests: + # Still sort into correct pending list for potential future logic if is_disagg_init: pending_disagg_requests.append(request) else: pending_context_requests.append(request) continue - # Prioritize Running Requests (Generation) if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( request, self.default_window_size) - if needed > available_blocks: - # If running req doesn't fit, skip it (effectively pause) - pending_context_requests.append(request) - continue - scheduled_requests.append(request) + # Decrement available. It is allowed to go negative (temporarily) + # to represent over-subscription, which stops new requests in Pass 2. available_blocks -= needed else: # Non-running requests go to pending @@ -662,8 +657,6 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): if len(scheduled_requests) >= self.max_num_requests: break - # Note: For Disagg Init, get_remaining_blocks_to_completion calculates - # full prompt + generation needs, which is what we want. 
needed_blocks = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( request, self.default_window_size) @@ -671,7 +664,7 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests.append(request) available_blocks -= needed_blocks else: - # Head-of-line blocking (Standard NoEvict behavior) + # Head-of-line blocking break return self._classify_output(active_requests, scheduled_requests, []) From 63c09c64ef022b1609c89efc02c5ca39ebca4755 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Thu, 18 Dec 2025 10:17:22 +0800 Subject: [PATCH 18/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 76 +++++++++++++-------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 32e2a31483f..2064c5827b0 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -495,6 +495,39 @@ def schedule_request( raise NotImplementedError( f"Policy {self.policy} not implemented in PyCapacityScheduler") + def _get_req_needed_blocks(self, req: LlmRequest, + is_guaranteed_no_evict: bool) -> int: + """ + Robustly calculate needed blocks. + Tries to use C++ API if safe, otherwise falls back to Python math. + """ + # For GuaranteedNoEvict, we need Total Blocks to completion. + # For MaxUtilization, we typically need incremental blocks (1 step). + + # Scenario A: Disagg Init or Context Init (New Requests) + # The sequence might NOT be in C++ Manager yet. Do math manually to be safe. + if req.state == LlmRequestState.DISAGG_GENERATION_INIT or req.state == LlmRequestState.CONTEXT_INIT: + if is_guaranteed_no_evict: + # Need TOTAL blocks for the whole request + total_len = req.prompt_len + req.max_new_tokens + return math.ceil(total_len / self.tokens_per_block) + else: + # Need just the Context blocks for now (Simplified) + # Or use C++ get_needed_blocks_one_step if it doesn't lookup sequence + return math.ceil(req.prompt_len / self.tokens_per_block) + + # Scenario B: Generation (Running Requests) + # Sequence IS in C++ Manager. Use C++ API for accuracy (shared blocks, beam width, etc). + if req.state == LlmRequestState.GENERATION_IN_PROGRESS or req.state == LlmRequestState.GENERATION_TO_COMPLETE: + if is_guaranteed_no_evict: + return self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( + req, self.default_window_size) + else: + return self.kv_cache_manager_cpp.get_needed_blocks_one_step( + req, False, self.default_window_size) + + return 0 + def _schedule_max_requests(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] @@ -540,9 +573,6 @@ def _schedule_max_utilization(self, active_requests: RequestList): while idx < len(cached_active_list): req = cached_active_list[idx] - - # 1. State Filter - # Allow Disagg Gen Init to pass through (matching C++ logic) is_disagg_init = ( req.state == LlmRequestState.DISAGG_GENERATION_INIT) @@ -552,21 +582,12 @@ def _schedule_max_utilization(self, active_requests: RequestList): idx += 1 continue - # 2. Max Requests Check if len(scheduled_requests) >= self.max_num_requests: break - # 3. 
Try Allocation (Python Manual Check) - needed_blocks = 0 - if is_disagg_init: - needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, self.default_window_size) - elif req.state == LlmRequestState.GENERATION_IN_PROGRESS: - needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, self.default_window_size) - elif req.state == LlmRequestState.CONTEXT_INIT: - needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, self.default_window_size) + # 3. Try Allocation + needed_blocks = self._get_req_needed_blocks( + req, is_guaranteed_no_evict=False) if current_free_blocks >= needed_blocks: current_free_blocks -= needed_blocks @@ -574,7 +595,7 @@ def _schedule_max_utilization(self, active_requests: RequestList): idx += 1 continue - # 4. Backtracking / Eviction Logic + # 4. Backtracking victim_idx = -1 for i in range(len(scheduled_requests) - 1, -1, -1): r = scheduled_requests[i] @@ -588,10 +609,9 @@ def _schedule_max_utilization(self, active_requests: RequestList): paused_requests.append(victim_req) # Reclaim victim's blocks manually - victim_needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - victim_req, False, self.default_window_size) + victim_needed = self._get_req_needed_blocks( + victim_req, is_guaranteed_no_evict=False) current_free_blocks += victim_needed - # Retry current req without incrementing idx continue else: @@ -619,28 +639,28 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): is_disagg_init = ( req_state == LlmRequestState.DISAGG_GENERATION_INIT) + # Filter valid states (Allow Disagg Init) if not is_disagg_init and ( req_state.value < self.no_schedule_until_state.value or req_state.value >= self.no_schedule_after_state.value): continue + # Capacity Check if len(scheduled_requests) >= self.max_num_requests: - # Still sort into correct pending list for potential future logic if is_disagg_init: pending_disagg_requests.append(request) else: pending_context_requests.append(request) continue + # Prioritize Running Requests (Generation) + # In C++ NoEvict, running requests are scheduled unconditionally in the first pass if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): - needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( - request, self.default_window_size) - + needed = self._get_req_needed_blocks( + request, is_guaranteed_no_evict=True) scheduled_requests.append(request) - # Decrement available. It is allowed to go negative (temporarily) - # to represent over-subscription, which stops new requests in Pass 2. 
available_blocks -= needed else: # Non-running requests go to pending @@ -657,14 +677,14 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): if len(scheduled_requests) >= self.max_num_requests: break - needed_blocks = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( - request, self.default_window_size) + needed_blocks = self._get_req_needed_blocks( + request, is_guaranteed_no_evict=True) if needed_blocks <= available_blocks: scheduled_requests.append(request) available_blocks -= needed_blocks else: - # Head-of-line blocking + # Head-of-line blocking (Standard NoEvict behavior) break return self._classify_output(active_requests, scheduled_requests, []) From c2bffa5ae557491f5c1d9b35e0ce6fbd134e0568 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Thu, 18 Dec 2025 10:20:02 +0800 Subject: [PATCH 19/25] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 344 +++++++++++--------- 1 file changed, 188 insertions(+), 156 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 2064c5827b0..e5466f6a158 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -10,6 +10,7 @@ from tensorrt_llm.bindings import internal as tb_internal from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy +# Assuming these imports exist in your environment from .llm_request import LlmRequest, LlmRequestState RequestList = list[LlmRequest] @@ -247,6 +248,10 @@ class ContextChunkingConfig: chunk_unit_size: int +class MicroBatchScheduler: + """Base class to match structure.""" + + class PyMicroBatchScheduler(MicroBatchScheduler): def __init__( @@ -259,6 +264,7 @@ def __init__( self.max_batch_size = max_batch_size self.max_num_tokens = max_num_tokens self.ctx_chunk_config = ctx_chunk_config + self.max_context_length = max_num_tokens def schedule( self, active_requests: RequestList, @@ -267,102 +273,137 @@ def schedule( context_requests: RequestList = [] generation_requests: RequestList = [] - current_batch_tokens = 0 - scheduled_req_count = 0 + # Current total tokens in the scheduled batch (Generation + Context) + batch_num_tokens = 0 + scheduled_req_size = 0 scheduled_beam_width = 0 contexts_to_be_chunked: RequestList = [] + # Total tokens required by chunked requests (calculated tentatively) num_chunked_tokens = 0 - all_context_fits = True + all_context_requests_fit = True - # 1. First Pass: Filter & Categorize (Generation First) + # 1. Main Scheduling Loop for req in active_requests: - # Skip invalid states (Simplified check, assuming caller filters mostly) + # Skip requests already in flight (should be filtered by caller, but C++ checks) if req.request_id in inflight_request_ids: continue - # --- Generation Handling --- - if req.state == LlmRequestState.GENERATION_IN_PROGRESS: - beam_width = req.sampling_config.beam_width - req_num_tokens = beam_width + req.num_draft_tokens + req_num_tokens = 0 + + # --- A. 
Encoder Request Handling (Previously Missing) --- + if req.state == LlmRequestState.ENCODER_INIT: + # C++: reqNumTokens = llmReq->getEncoderOutputLen(); + req_num_tokens = req.encoder_output_len - # Check Global Token Budget - if self.max_num_tokens is not None and (current_batch_tokens + + if self.max_context_length is not None and req_num_tokens > self.max_context_length: + # C++ does TLLM_CHECK here. We skip or log. + continue + + # Check Batch Token Budget + if self.max_num_tokens is not None and (batch_num_tokens + req_num_tokens > self.max_num_tokens): break - # Check Beam Width Consistency (Batch constraint) - if scheduled_beam_width == 0: - scheduled_beam_width = beam_width - elif scheduled_beam_width != beam_width: - continue - - generation_requests.append(req) - current_batch_tokens += req_num_tokens + context_requests.append(req) + batch_num_tokens += req_num_tokens - # --- Context Handling --- + # --- B. Context Request Handling --- elif req.state == LlmRequestState.CONTEXT_INIT: if not self.ctx_chunk_config: - # No Chunking: Greedy allocation - req_num_tokens = req.context_remaining_length + # No Chunking: Schedule full context + # C++: getNumTokens(beam) + (hasDraft ? getNumDraftTokens : 0) + base_tokens = req.context_remaining_length # effectively getNumTokens(0) draft_tokens = req.num_draft_tokens if req.has_draft_tokens else 0 - total_tokens = req_num_tokens + draft_tokens + req_num_tokens = base_tokens + draft_tokens + + if self.max_context_length is not None and req_num_tokens > self.max_context_length: + continue if self.max_num_tokens is not None and ( - current_batch_tokens + total_tokens + batch_num_tokens + req_num_tokens > self.max_num_tokens): break context_requests.append(req) - current_batch_tokens += total_tokens + batch_num_tokens += req_num_tokens else: - # Chunking Enabled: Defer calculation - remaining = req.context_remaining_length - # Just an estimate for budget check - req.context_chunk_size = remaining + # Chunking Enabled: Tentative schedule + # C++: setContextChunkSize(remaining); reqNumTokens = size + draft + req.context_chunk_size = req.context_remaining_length draft_tokens = req.num_draft_tokens if ( req.is_last_context_chunk and req.has_draft_tokens) else 0 - req_num_tokens = remaining + draft_tokens + req_num_tokens = req.context_chunk_size + draft_tokens + + # C++: Check maxContextLength constraints + if self.max_context_length is not None: + if self.max_context_length < req_num_tokens: + req_num_tokens = self.max_context_length + all_context_requests_fit = False contexts_to_be_chunked.append(req) num_chunked_tokens += req_num_tokens - # Batch Size Check - scheduled_req_count += 1 - if scheduled_req_count >= self.max_batch_size: + # --- C. Generation Request Handling --- + else: + beam_width = req.sampling_config.beam_width + req_num_tokens = beam_width + req.num_draft_tokens + + if self.max_num_tokens is not None and (batch_num_tokens + + req_num_tokens + > self.max_num_tokens): + break + + # Beam Width Consistency Check (C++ Logic) + if scheduled_beam_width == 0: + scheduled_beam_width = beam_width + elif scheduled_beam_width != beam_width: + # Skip requests with different beam width in this batch + continue + + generation_requests.append(req) + batch_num_tokens += req_num_tokens + + # --- Batch Size Limit Check --- + scheduled_req_size += 1 + if scheduled_req_size >= self.max_batch_size: break - # 2. Check if chunking logic is needed + # 2. 
Verify Chunking Fits if self.max_num_tokens is not None and num_chunked_tokens > ( - self.max_num_tokens - current_batch_tokens): - all_context_fits = False + self.max_num_tokens - batch_num_tokens): + all_context_requests_fit = False - # 3. Apply Chunking Strategy - if not all_context_fits and contexts_to_be_chunked: + # 3. Apply Chunking Strategy if needed + if not all_context_requests_fit and contexts_to_be_chunked: if not self.ctx_chunk_config: - # Should effectively be handled above, but as a fallback - pass + pass # Error in C++: "If chunking not enabled..." else: remaining_capacity = ( - self.max_num_tokens - current_batch_tokens + self.max_num_tokens - batch_num_tokens ) if self.max_num_tokens is not None else None + self._set_ctx_requests_chunk_size(contexts_to_be_chunked, remaining_capacity) - # 4. Finalize Context Requests + # 4. Finalize Chunked Requests for req in contexts_to_be_chunked: if req.context_chunk_size > 0: context_requests.append(req) - current_batch_tokens += req.context_chunk_size + # C++: batchNumTokens += chunk size + batch_num_tokens += req.context_chunk_size + + # Note: C++ calls utils::sortRequests here. Python lists preserve order, + # usually acceptable unless specific downstream kernel requirements exist. return context_requests, generation_requests def _set_ctx_requests_chunk_size(self, requests: RequestList, capacity: Optional[int]): - # Reset + # C++: Resets all chunk sizes to 0 at start for req in requests: req.context_chunk_size = 0 @@ -374,84 +415,95 @@ def _set_ctx_requests_chunk_size(self, requests: RequestList, elif policy == ChunkingPolicy.FIRST_COME_FIRST_SERVED: self._chunk_fcfs(requests, capacity, unit_size) - # Optimization: Fit draft tokens if space allows self._fit_draft_tokens(requests, capacity, unit_size) - def _chunk_equal_progress(self, requests, capacity, unit_size): + def _chunk_equal_progress(self, requests: RequestList, + capacity: Optional[int], unit_size: int): num_ctx_tokens = 0 - made_progress = True + num_tokens_single_loop = 1 - while (capacity is None or num_ctx_tokens < capacity) and made_progress: - made_progress = False + # C++ Loop: while ((!capacity || numCtxTokens < capacity) && numTokensSingleLoop) + while (capacity is None + or num_ctx_tokens < capacity) and num_tokens_single_loop > 0: + num_tokens_single_loop = 0 for req in requests: past_size = req.context_chunk_size - remaining = req.context_remaining_length - if past_size >= remaining: + # C++ logic: suggested = past + unit + suggested_size = past_size + unit_size + + # Ensure we don't exceed what the request actually needs + remaining_total = req.context_remaining_length + suggested_size = min(suggested_size, remaining_total) + + req.context_chunk_size = suggested_size + + actual_size = req.context_chunk_size + actual_increment = actual_size - past_size + + # Check Constraints + # 1. Capacity + if capacity is not None and (num_ctx_tokens + actual_increment + > capacity): + req.context_chunk_size = past_size # Revert continue - suggested_size = past_size + unit_size - actual_size = min(suggested_size, remaining) - increment = actual_size - past_size - - if increment > 0: - if capacity is not None and (num_ctx_tokens + increment - > capacity): - # Cannot fit this increment, stop growing this request - req.context_chunk_size = past_size - continue + # 2. 
Max Context Length + if self.max_context_length is not None and actual_size > self.max_context_length: + req.context_chunk_size = past_size # Revert + continue - req.context_chunk_size = actual_size - num_ctx_tokens += increment - made_progress = True + num_ctx_tokens += actual_increment + num_tokens_single_loop += actual_increment - def _chunk_fcfs(self, requests, capacity, unit_size): + def _chunk_fcfs(self, requests: RequestList, capacity: Optional[int], + unit_size: int): current_capacity = capacity if capacity is not None else float('inf') for req in requests: - remaining = req.context_remaining_length - actual_size = remaining + suggested_size = req.context_remaining_length + actual_size = suggested_size if current_capacity < actual_size: actual_size = current_capacity - # Align if truncated - if actual_size < remaining: + if self.max_context_length is not None: + actual_size = min(self.max_context_length, actual_size) + + # Round down to unit size if we had to truncate + if actual_size < suggested_size: actual_size = (int(actual_size) // unit_size) * unit_size req.context_chunk_size = int(actual_size) - current_capacity -= req.context_chunk_size - if current_capacity <= 0: - break + # C++: ctxTokensCapacity = ctxTokensCapacity - actualChunkSize + if capacity is not None: + current_capacity -= req.context_chunk_size - def _fit_draft_tokens(self, requests, capacity, unit_size): - # Python port of fitDraftTokens - # Logic: If it is the last chunk, try to fit draft tokens without using a new KV block - current_tokens = sum(r.context_chunk_size for r in requests) + def _fit_draft_tokens(self, requests: RequestList, capacity: Optional[int], + unit_size: int): + # Calculate tokens already taken by the batch so far + num_ctx_tokens = sum(req.context_chunk_size for req in requests) for req in requests: if req.is_last_context_chunk and req.has_draft_tokens: - chunk_size = req.context_chunk_size - remainder = chunk_size % unit_size - # Space left in the last block - space_in_block = 0 if remainder == 0 else (unit_size - - remainder) - - # Check constraints - allowed_space = space_in_block + remainder = req.context_chunk_size % unit_size + remaining_space = 0 if remainder == 0 else unit_size - remainder + + if self.max_context_length is not None: + remaining_context_len = self.max_context_length - req.context_chunk_size + remaining_space = min(remaining_space, + remaining_context_len) + if capacity is not None: - allowed_space = min(allowed_space, - capacity - current_tokens) - - # If we can't fit all draft tokens in the existing block/capacity, discard them - draft_needed = req.num_draft_tokens - if draft_needed > allowed_space: - # In python we might need a method to discard/update draft tokens on req - # req.discard_draft_tokens(draft_needed - allowed_space) - pass - else: - current_tokens += draft_needed + remaining_space = min(remaining_space, + capacity - num_ctx_tokens) + num_ctx_tokens += remaining_space + + draft_discard = req.num_draft_tokens - remaining_space + if draft_discard > 0: + if hasattr(req, "discard_draft_tokens"): + req.discard_draft_tokens(draft_discard) class PyCapacityScheduler: @@ -495,39 +547,6 @@ def schedule_request( raise NotImplementedError( f"Policy {self.policy} not implemented in PyCapacityScheduler") - def _get_req_needed_blocks(self, req: LlmRequest, - is_guaranteed_no_evict: bool) -> int: - """ - Robustly calculate needed blocks. - Tries to use C++ API if safe, otherwise falls back to Python math. 
- """ - # For GuaranteedNoEvict, we need Total Blocks to completion. - # For MaxUtilization, we typically need incremental blocks (1 step). - - # Scenario A: Disagg Init or Context Init (New Requests) - # The sequence might NOT be in C++ Manager yet. Do math manually to be safe. - if req.state == LlmRequestState.DISAGG_GENERATION_INIT or req.state == LlmRequestState.CONTEXT_INIT: - if is_guaranteed_no_evict: - # Need TOTAL blocks for the whole request - total_len = req.prompt_len + req.max_new_tokens - return math.ceil(total_len / self.tokens_per_block) - else: - # Need just the Context blocks for now (Simplified) - # Or use C++ get_needed_blocks_one_step if it doesn't lookup sequence - return math.ceil(req.prompt_len / self.tokens_per_block) - - # Scenario B: Generation (Running Requests) - # Sequence IS in C++ Manager. Use C++ API for accuracy (shared blocks, beam width, etc). - if req.state == LlmRequestState.GENERATION_IN_PROGRESS or req.state == LlmRequestState.GENERATION_TO_COMPLETE: - if is_guaranteed_no_evict: - return self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( - req, self.default_window_size) - else: - return self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, self.default_window_size) - - return 0 - def _schedule_max_requests(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] @@ -561,8 +580,7 @@ def _schedule_max_utilization(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] - if hasattr(self.kv_cache_manager, "start_scheduling"): - self.kv_cache_manager.start_scheduling() + self.kv_cache_manager_cpp.start_scheduling() # Track free blocks manually in Python to simulate transactional state. stats = self.kv_cache_manager_cpp.get_kv_cache_stats() @@ -573,6 +591,9 @@ def _schedule_max_utilization(self, active_requests: RequestList): while idx < len(cached_active_list): req = cached_active_list[idx] + + # 1. State Filter + # Allow Disagg Gen Init to pass through (matching C++ logic) is_disagg_init = ( req.state == LlmRequestState.DISAGG_GENERATION_INIT) @@ -582,12 +603,21 @@ def _schedule_max_utilization(self, active_requests: RequestList): idx += 1 continue + # 2. Max Requests Check if len(scheduled_requests) >= self.max_num_requests: break - # 3. Try Allocation - needed_blocks = self._get_req_needed_blocks( - req, is_guaranteed_no_evict=False) + # 3. Try Allocation (Python Manual Check) + needed_blocks = 0 + if is_disagg_init: + needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( + req, False, self.default_window_size) + elif req.state == LlmRequestState.GENERATION_IN_PROGRESS: + needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( + req, False, self.default_window_size) + elif req.state == LlmRequestState.CONTEXT_INIT: + needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( + req, False, self.default_window_size) if current_free_blocks >= needed_blocks: current_free_blocks -= needed_blocks @@ -595,7 +625,7 @@ def _schedule_max_utilization(self, active_requests: RequestList): idx += 1 continue - # 4. Backtracking + # 4. 
Backtracking / Eviction Logic victim_idx = -1 for i in range(len(scheduled_requests) - 1, -1, -1): r = scheduled_requests[i] @@ -609,9 +639,10 @@ def _schedule_max_utilization(self, active_requests: RequestList): paused_requests.append(victim_req) # Reclaim victim's blocks manually - victim_needed = self._get_req_needed_blocks( - victim_req, is_guaranteed_no_evict=False) + victim_needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( + victim_req, False, self.default_window_size) current_free_blocks += victim_needed + # Retry current req without incrementing idx continue else: @@ -623,15 +654,12 @@ def _schedule_max_utilization(self, active_requests: RequestList): def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests: RequestList = [] - - # Pending lists separated to enforce priority: Disagg Init > Context Init pending_disagg_requests: RequestList = [] - pending_context_requests: RequestList = [] + pending_requests: RequestList = [] stats = self.kv_cache_manager_cpp.get_kv_cache_stats() - max_blocks = stats.max_num_blocks - used_blocks = stats.used_num_blocks - available_blocks = max_blocks - used_blocks + # available_blocks represents PHYSICAL free blocks + available_blocks = stats.max_num_blocks - stats.used_num_blocks # --- Pass 1: Running Requests --- for request in active_requests: @@ -639,46 +667,50 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): is_disagg_init = ( req_state == LlmRequestState.DISAGG_GENERATION_INIT) - # Filter valid states (Allow Disagg Init) if not is_disagg_init and ( req_state.value < self.no_schedule_until_state.value or req_state.value >= self.no_schedule_after_state.value): continue - # Capacity Check + # If capacity is full, move to pending if len(scheduled_requests) >= self.max_num_requests: if is_disagg_init: pending_disagg_requests.append(request) else: - pending_context_requests.append(request) + pending_requests.append(request) continue - # Prioritize Running Requests (Generation) - # In C++ NoEvict, running requests are scheduled unconditionally in the first pass + # Unconditionally schedule Running Requests (Match C++ NoEvict logic) if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): - needed = self._get_req_needed_blocks( - request, is_guaranteed_no_evict=True) + needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( + request, self.default_window_size) + scheduled_requests.append(request) + # Subtract needed blocks from availability. + # This can go negative, effectively reserving space for these requests + # and blocking new ones in Pass 2. available_blocks -= needed else: - # Non-running requests go to pending if is_disagg_init: pending_disagg_requests.append(request) else: - pending_context_requests.append(request) + pending_requests.append(request) - # --- Pass 2: New / Context Requests --- - # Critical: Process Disagg Init requests BEFORE standard Context Init - all_pending = pending_disagg_requests + pending_context_requests + # --- Pass 2: New / Context Requests (Disagg First) --- + all_pending = pending_disagg_requests + pending_requests for request in all_pending: if len(scheduled_requests) >= self.max_num_requests: break - needed_blocks = self._get_req_needed_blocks( - request, is_guaranteed_no_evict=True) + # If running requests have reserved all (or more) than available space, stop. 
+ if available_blocks <= 0: + break + + needed_blocks = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( + request, self.default_window_size) if needed_blocks <= available_blocks: scheduled_requests.append(request) From 2a3a7f27df2b0d222c87bbca8b1a7355543529f2 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Fri, 19 Dec 2025 08:49:24 +0800 Subject: [PATCH 20/25] fix gemma Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/scheduler.py | 181 +++++++++++++------- 1 file changed, 120 insertions(+), 61 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index e5466f6a158..2fa8ec56014 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -3,7 +3,7 @@ from collections import namedtuple from dataclasses import dataclass from enum import Enum -from typing import Optional, Set, Tuple +from typing import Dict, Optional, Set, Tuple from strenum import StrEnum @@ -509,6 +509,7 @@ def _fit_draft_tokens(self, requests: RequestList, capacity: Optional[int], class PyCapacityScheduler: """ Python implementation of the C++ CapacityScheduler. + Aligned with C++ logic to support Multiple Window Sizes (VSWA). """ def __init__( @@ -534,11 +535,9 @@ def schedule_request( self, active_requests: RequestList ) -> Tuple[RequestList, RequestList, RequestList]: - # 1. Handle No KV Cache Manager -> MaxRequestsScheduler Logic if self.kv_cache_manager is None: return self._schedule_max_requests(active_requests) - # 2. Handle Policies with KV Cache Manager if self.policy == CapacitySchedulerPolicy.MAX_UTILIZATION: return self._schedule_max_utilization(active_requests) elif self.policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: @@ -547,13 +546,95 @@ def schedule_request( raise NotImplementedError( f"Policy {self.policy} not implemented in PyCapacityScheduler") + def _get_initial_available_blocks_map(self) -> Dict[int, int]: + """ + Mimics C++: mKvCacheManager.getBlockManager().getNumFreeBlocksPerWindowSize() + Returns a dict {window_size: free_blocks}. + """ + stats = self.kv_cache_manager_cpp.get_kv_cache_stats() + + # Nanobind binds std::map to python dict + # Property name from binding: .def_rw("num_free_blocks_per_window_size", ...) + free_map = stats.num_free_blocks_per_window_size + + if not free_map: + # Fallback for simple cases or if map is empty (though unlikely in C++) + # Calculate scalar free blocks + free_scalar = stats.max_num_blocks - stats.used_num_blocks + return {self.default_window_size: free_scalar} + + # Ensure we return a copy so we can modify it during scheduling + return dict(free_map) + + def _req_check_and_update_map(self, req: LlmRequest, + available_map: Dict[int, int], + is_guaranteed_no_evict: bool) -> bool: + """ + Checks if a request fits in ALL window sizes tracked in available_map. + If it fits, decrements the map and returns True. + If it doesn't fit, leaves map untouched and returns False. + """ + # 1. 
Calculate needed blocks for all window sizes + needed_per_window = {} + for window_size in available_map.keys(): + if is_guaranteed_no_evict: + # C++: getRemainingBlocksToCompletion(req, windowSize) + needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( + req, window_size) + else: + # C++: getNeededBlocksOneStep(req, twoStepsLookAhead, windowSize) + needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( + req, False, window_size) + needed_per_window[window_size] = needed + + # 2. Check if fits (All or Nothing) + for window_size, available in available_map.items(): + if needed_per_window[window_size] > available: + return False + + # 3. Commit update + for window_size in available_map.keys(): + available_map[window_size] -= needed_per_window[window_size] + + return True + + def _req_force_update_map(self, req: LlmRequest, available_map: Dict[int, + int], + is_guaranteed_no_evict: bool): + """ + Unconditionally decrements the available blocks (used for Running requests in NoEvict). + Allowed to go negative. + """ + for window_size in available_map.keys(): + if is_guaranteed_no_evict: + needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( + req, window_size) + else: + needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( + req, False, window_size) + + available_map[window_size] -= needed + + def _req_revert_map(self, req: LlmRequest, available_map: Dict[int, int], + is_guaranteed_no_evict: bool): + """ + Reverts a decrement (used for Backtracking in MaxUtilization). + """ + for window_size in available_map.keys(): + if is_guaranteed_no_evict: + needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( + req, window_size) + else: + needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( + req, False, window_size) + + available_map[window_size] += needed + def _schedule_max_requests(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] for req in active_requests: - # 1. State Filter - # Allow Disagg Gen Init to pass through is_disagg_init = ( req.state == LlmRequestState.DISAGG_GENERATION_INIT) @@ -562,11 +643,9 @@ def _schedule_max_requests(self, active_requests: RequestList): or req.state.value >= self.no_schedule_after_state.value): continue - # 2. Max Requests Check if len(scheduled_requests) >= self.max_num_requests: break - # 3. Schedule valid states if (req.state == LlmRequestState.ENCODER_INIT or req.state == LlmRequestState.CONTEXT_INIT or req.state == LlmRequestState.GENERATION_IN_PROGRESS @@ -580,11 +659,11 @@ def _schedule_max_utilization(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] - self.kv_cache_manager_cpp.start_scheduling() + if hasattr(self.kv_cache_manager, "start_scheduling"): + self.kv_cache_manager.start_scheduling() - # Track free blocks manually in Python to simulate transactional state. - stats = self.kv_cache_manager_cpp.get_kv_cache_stats() - current_free_blocks = stats.free_num_blocks + # [FIX] Use Map tracking for multiple window sizes + current_free_blocks_map = self._get_initial_available_blocks_map() cached_active_list = list(active_requests) idx = 0 @@ -592,8 +671,6 @@ def _schedule_max_utilization(self, active_requests: RequestList): while idx < len(cached_active_list): req = cached_active_list[idx] - # 1. 
State Filter - # Allow Disagg Gen Init to pass through (matching C++ logic) is_disagg_init = ( req.state == LlmRequestState.DISAGG_GENERATION_INIT) @@ -603,29 +680,19 @@ def _schedule_max_utilization(self, active_requests: RequestList): idx += 1 continue - # 2. Max Requests Check if len(scheduled_requests) >= self.max_num_requests: break - # 3. Try Allocation (Python Manual Check) - needed_blocks = 0 - if is_disagg_init: - needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, self.default_window_size) - elif req.state == LlmRequestState.GENERATION_IN_PROGRESS: - needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, self.default_window_size) - elif req.state == LlmRequestState.CONTEXT_INIT: - needed_blocks = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, self.default_window_size) - - if current_free_blocks >= needed_blocks: - current_free_blocks -= needed_blocks + # 3. Try Allocation + # C++ Logic: Checks if it fits in *all* window sizes + if self._req_check_and_update_map(req, + current_free_blocks_map, + is_guaranteed_no_evict=False): scheduled_requests.append(req) idx += 1 continue - # 4. Backtracking / Eviction Logic + # 4. Backtracking (Evict Generation requests only) victim_idx = -1 for i in range(len(scheduled_requests) - 1, -1, -1): r = scheduled_requests[i] @@ -638,12 +705,12 @@ def _schedule_max_utilization(self, active_requests: RequestList): victim_req = scheduled_requests.pop(victim_idx) paused_requests.append(victim_req) - # Reclaim victim's blocks manually - victim_needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - victim_req, False, self.default_window_size) - current_free_blocks += victim_needed + # Revert victim's usage in the map + self._req_revert_map(victim_req, + current_free_blocks_map, + is_guaranteed_no_evict=False) - # Retry current req without incrementing idx + # Retry current req (do NOT increment idx) continue else: # No victim found, and current request doesn't fit. Stop. 
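The map-based bookkeeping in this hunk follows a check/commit/revert pattern: a candidate request is tested against every attention-window size, the per-window budgets are decremented only if all of them fit, and the same amounts are added back when a previously scheduled generation request is paused. A minimal, self-contained sketch of that pattern (illustrative only; the toy WindowBudget class and needs_two_blocks function below are stand-ins for the real get_needed_blocks_one_step binding used above) could look like this:

from typing import Callable, Dict


class WindowBudget:
    """Toy per-window-size block budget mirroring the scheduler's map tracking."""

    def __init__(self, free_blocks_per_window: Dict[int, int]):
        # e.g. {4096: 3, 1024: 10} for a VSWA model with two window sizes
        self.free = dict(free_blocks_per_window)

    def try_reserve(self, needed_fn: Callable[[int], int]) -> bool:
        """All-or-nothing: reserve only if the request fits every window size."""
        needed = {ws: needed_fn(ws) for ws in self.free}
        if any(needed[ws] > avail for ws, avail in self.free.items()):
            return False  # leave the budget untouched; the caller may backtrack
        for ws in self.free:
            self.free[ws] -= needed[ws]
        return True

    def release(self, needed_fn: Callable[[int], int]) -> None:
        """Revert a reservation, e.g. when a scheduled request is paused."""
        for ws in self.free:
            self.free[ws] += needed_fn(ws)


if __name__ == "__main__":

    def needs_two_blocks(window_size: int) -> int:
        # hypothetical stand-in for get_needed_blocks_one_step(req, False, window_size)
        return 2

    budget = WindowBudget({4096: 3, 1024: 10})
    assert budget.try_reserve(needs_two_blocks)      # fits everywhere: 3/10 -> 1/8
    assert not budget.try_reserve(needs_two_blocks)  # 2 > 1 for the 4096 window
    budget.release(needs_two_blocks)                 # pausing the victim restores 3/10
    assert budget.try_reserve(needs_two_blocks)

This is only a sketch of the accounting; the scheduler in this patch additionally filters request states, enforces max_num_requests, and chooses eviction victims before retrying the rejected request.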
@@ -655,11 +722,11 @@ def _schedule_max_utilization(self, active_requests: RequestList): def _schedule_guaranteed_no_evict(self, active_requests: RequestList): scheduled_requests: RequestList = [] pending_disagg_requests: RequestList = [] - pending_requests: RequestList = [] + pending_context_requests: RequestList = [] - stats = self.kv_cache_manager_cpp.get_kv_cache_stats() - # available_blocks represents PHYSICAL free blocks - available_blocks = stats.max_num_blocks - stats.used_num_blocks + # [FIX] Use Map tracking for multiple window sizes + # Note: C++ NoEvictScheduledBlocksManager initializes with getNumFreeBlocksPerWindowSize() + available_blocks_map = self._get_initial_available_blocks_map() # --- Pass 1: Running Requests --- for request in active_requests: @@ -672,51 +739,44 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): or req_state.value >= self.no_schedule_after_state.value): continue - # If capacity is full, move to pending if len(scheduled_requests) >= self.max_num_requests: if is_disagg_init: pending_disagg_requests.append(request) else: - pending_requests.append(request) + pending_context_requests.append(request) continue - # Unconditionally schedule Running Requests (Match C++ NoEvict logic) + # Unconditionally schedule Running Requests if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): - needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( - request, self.default_window_size) - scheduled_requests.append(request) - # Subtract needed blocks from availability. - # This can go negative, effectively reserving space for these requests - # and blocking new ones in Pass 2. - available_blocks -= needed + + # [FIX] Update Map unconditionally (can go negative) + self._req_force_update_map(request, + available_blocks_map, + is_guaranteed_no_evict=True) else: if is_disagg_init: pending_disagg_requests.append(request) else: - pending_requests.append(request) + pending_context_requests.append(request) # --- Pass 2: New / Context Requests (Disagg First) --- - all_pending = pending_disagg_requests + pending_requests + all_pending = pending_disagg_requests + pending_context_requests for request in all_pending: if len(scheduled_requests) >= self.max_num_requests: break - # If running requests have reserved all (or more) than available space, stop. 
- if available_blocks <= 0: - break - - needed_blocks = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( - request, self.default_window_size) - - if needed_blocks <= available_blocks: + # [FIX] Check using Map logic + # C++ enoughAvailableBlocks checks: needed <= available for ALL window sizes + if self._req_check_and_update_map(request, + available_blocks_map, + is_guaranteed_no_evict=True): scheduled_requests.append(request) - available_blocks -= needed_blocks else: - # Head-of-line blocking (Standard NoEvict behavior) + # Head-of-line blocking break return self._classify_output(active_requests, scheduled_requests, []) @@ -730,7 +790,6 @@ def _classify_output(self, active_requests, scheduled_requests, scheduled_ids = set(r.request_id for r in scheduled_requests) paused_ids = set(r.request_id for r in paused_requests) - # Identify running requests that were implicitly paused (dropped) for req in active_requests: if (req.request_id not in scheduled_ids and req.request_id not in paused_ids From 2f30b99f230a69b10868daab89664c23ab70ff03 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Fri, 19 Dec 2025 10:23:42 +0800 Subject: [PATCH 21/25] fix lora Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 5 +- tensorrt_llm/_torch/pyexecutor/scheduler.py | 76 ++++++++++++++++----- 2 files changed, 60 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 7477953638d..0f6252499c2 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -808,15 +808,12 @@ def create_py_executor_instance( use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1" if use_python_scheduler: - if peft_cache_manager is not None: - logger.warning( - "PeftCacheManager is currently ignored by Python Scheduler (Phase 1)." - ) scheduler = SimpleUnifiedScheduler( max_batch_size=max_batch_size, max_num_tokens=max_num_tokens, kv_cache_manager=kv_cache_manager, + peft_cache_manager=peft_cache_manager, scheduler_policy=scheduler_config.capacity_scheduler_policy, ctx_chunk_config=ctx_chunk_config) else: diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 2fa8ec56014..be610431c31 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -516,6 +516,7 @@ def __init__( self, max_num_requests: int, kv_cache_manager, + peft_cache_manager=None, scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. 
MAX_UTILIZATION, no_schedule_until_state=LlmRequestState.CONTEXT_INIT, @@ -523,6 +524,7 @@ def __init__( ): self.max_num_requests = max_num_requests self.kv_cache_manager = kv_cache_manager + self.peft_cache_manager = peft_cache_manager self.policy = scheduler_policy self.no_schedule_until_state = no_schedule_until_state self.no_schedule_after_state = no_schedule_after_state @@ -531,6 +533,12 @@ def __init__( self.kv_cache_manager_cpp = kv_cache_manager.impl self.default_window_size = self.kv_cache_manager.max_seq_len + if self.peft_cache_manager: + self.peft_cache_manager_cpp = self.peft_cache_manager.impl + self.max_peft_pages = self.peft_cache_manager_cpp.max_device_pages + else: + self.max_peft_pages = float('inf') # Effectively infinite + def schedule_request( self, active_requests: RequestList ) -> Tuple[RequestList, RequestList, RequestList]: @@ -659,8 +667,7 @@ def _schedule_max_utilization(self, active_requests: RequestList): scheduled_requests: RequestList = [] paused_requests: RequestList = [] - if hasattr(self.kv_cache_manager, "start_scheduling"): - self.kv_cache_manager.start_scheduling() + self.kv_cache_manager_cpp.start_scheduling() # [FIX] Use Map tracking for multiple window sizes current_free_blocks_map = self._get_initial_available_blocks_map() @@ -724,10 +731,13 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): pending_disagg_requests: RequestList = [] pending_context_requests: RequestList = [] - # [FIX] Use Map tracking for multiple window sizes - # Note: C++ NoEvictScheduledBlocksManager initializes with getNumFreeBlocksPerWindowSize() + # KV Cache Resource Tracking available_blocks_map = self._get_initial_available_blocks_map() + # PEFT Resource Tracking + claimed_peft_pages = 0 + uniq_task_ids: Set[int] = set() + # --- Pass 1: Running Requests --- for request in active_requests: req_state = request.state @@ -746,39 +756,69 @@ def _schedule_guaranteed_no_evict(self, active_requests: RequestList): pending_context_requests.append(request) continue - # Unconditionally schedule Running Requests if (req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE): - scheduled_requests.append(request) - - # [FIX] Update Map unconditionally (can go negative) + # 1. Update KV Cache Map (Unconditional) self._req_force_update_map(request, available_blocks_map, is_guaranteed_no_evict=True) + + # 2. 
Update PEFT Usage + # C++: if (isNewTask) claimedPeftPages += determineNumPages(req); + if self.peft_cache_manager and request.lora_task_id is not None: + task_id = request.lora_task_id + if task_id not in uniq_task_ids: + # Binding check: determine_num_pages + pages = self.peft_cache_manager_cpp.determine_num_pages( + request) + claimed_peft_pages += pages + uniq_task_ids.add(task_id) + + scheduled_requests.append(request) else: if is_disagg_init: pending_disagg_requests.append(request) else: pending_context_requests.append(request) - # --- Pass 2: New / Context Requests (Disagg First) --- + # --- Pass 2: New / Context Requests --- + available_peft_pages = self.max_peft_pages - claimed_peft_pages all_pending = pending_disagg_requests + pending_context_requests for request in all_pending: if len(scheduled_requests) >= self.max_num_requests: break - # [FIX] Check using Map logic - # C++ enoughAvailableBlocks checks: needed <= available for ALL window sizes - if self._req_check_and_update_map(request, - available_blocks_map, - is_guaranteed_no_evict=True): - scheduled_requests.append(request) - else: - # Head-of-line blocking + # 1. Check PEFT Capacity + needed_peft_pages = 0 + is_new_task = False + task_id = None + + if self.peft_cache_manager and request.lora_task_id is not None: + task_id = request.lora_task_id + is_new_task = (task_id not in uniq_task_ids) + if is_new_task: + needed_peft_pages = self.peft_cache_manager_cpp.determine_num_pages( + request) + if needed_peft_pages > available_peft_pages: + # Not enough PEFT memory + break # Head-of-line blocking + + # 2. Check KV Cache Capacity + if not self._req_check_and_update_map( + request, available_blocks_map, is_guaranteed_no_evict=True): + # Not enough KV blocks break + # 3. Commit Schedule + scheduled_requests.append(request) + + # Commit PEFT usage + if is_new_task: + available_peft_pages -= needed_peft_pages + uniq_task_ids.add(task_id) + return self._classify_output(active_requests, scheduled_requests, []) def _classify_output(self, active_requests, scheduled_requests, @@ -812,6 +852,7 @@ def __init__( max_batch_size: int, max_num_tokens: int, kv_cache_manager, + peft_cache_manager, scheduler_policy: CapacitySchedulerPolicy, ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None, ): @@ -819,6 +860,7 @@ def __init__( self.capacity_scheduler = PyCapacityScheduler( max_num_requests=max_batch_size, kv_cache_manager=kv_cache_manager, + peft_cache_manager=peft_cache_manager, scheduler_policy=scheduler_policy) # 2. 
Initialize Python MicroBatch Scheduler From 411c2542b7294feaf65c3045320e220b59ad4924 Mon Sep 17 00:00:00 2001 From: Lanyu Liao Date: Wed, 24 Dec 2025 01:16:30 -0800 Subject: [PATCH 22/25] implement scheduler using python Signed-off-by: Lanyu Liao --- .../nanobind/batch_manager/bindings.cpp | 15 +- .../nanobind/batch_manager/kvCacheManager.cpp | 9 +- .../pybind/batch_manager/bindings.cpp | 15 +- .../pybind/batch_manager/kvCacheManager.cpp | 9 +- tensorrt_llm/_torch/pyexecutor/_util.py | 4 +- tensorrt_llm/_torch/pyexecutor/scheduler.py | 991 ++++++++++++------ 6 files changed, 738 insertions(+), 305 deletions(-) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index 17c27f43bed..b75f8e8bd69 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -252,7 +252,20 @@ void initBindings(nb::module_& m) }) .def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest) .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics) - .def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel); + .def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel) + .def("get_unique_tokens", nb::overload_cast(&GenLlmReq::getUniqueTokens, nb::const_), + nb::arg("beam")) + .def("get_unique_tokens", nb::overload_cast<>(&GenLlmReq::getUniqueTokens, nb::const_)) + .def("get_encoder_unique_tokens", + [](GenLlmReq& self) + { + auto const& encoderUniqueTokens = self.getEncoderUniqueTokens(); + if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value()) + { + return std::optional(*encoderUniqueTokens.value()); + } + return std::optional(std::nullopt); + }); nb::class_(m, "LlmRequest", nb::dynamic_attr()) .def( diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 7a3bcae7cf1..2fb6ad95a9e 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -481,6 +481,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, nb::call_guard()) .def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse, nb::call_guard()) + .def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, nb::arg("unique_tokens"), + nb::arg("llm_request"), nb::call_guard()) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard()) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds, nb::call_guard()) @@ -524,7 +526,12 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true, nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr, nb::arg("enable_indexer_k_cache") = false, nb::arg("indexer_k_cache_quant_block_size") = 128, - nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard()); + nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard()) + .def( + "scheduling_has_free_blocks", + [](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize) + { return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); }, + nb::arg("num_required"), nb::arg("window_size"), nb::call_guard()); } void 
tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 1d98b0c623a..510571613e7 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -258,7 +258,20 @@ void initBindings(pybind11::module_& m) }) .def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest) .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics) - .def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel); + .def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel) + .def("get_unique_tokens", py::overload_cast(&GenLlmReq::getUniqueTokens, py::const_), + py::arg("beam")) + .def("get_unique_tokens", py::overload_cast<>(&GenLlmReq::getUniqueTokens, py::const_)) + .def("get_encoder_unique_tokens", + [](GenLlmReq& self) + { + auto const& encoderUniqueTokens = self.getEncoderUniqueTokens(); + if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value()) + { + return std::optional(*encoderUniqueTokens.value()); + } + return std::optional(std::nullopt); + }); py::classh(m, "LlmRequest", pybind11::dynamic_attr()) .def(py::init<>( diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 6ab03315e1a..2ef12236795 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -485,6 +485,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, py::call_guard()) .def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse, py::call_guard()) + .def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, py::arg("unique_tokens"), + py::arg("llm_request"), py::call_guard()) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard()) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds, py::call_guard()) @@ -519,7 +521,12 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, py::arg("kv_connector_manager") = nullptr, py::arg("enable_indexer_k_cache") = false, py::arg("indexer_k_cache_quant_block_size") = 128, py::arg("indexer_k_cache_index_head_dim") = 0, - py::call_guard()); + py::call_guard()) + .def( + "scheduling_has_free_blocks", + [](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize) + { return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); }, + py::arg("num_required"), py::arg("window_size"), py::call_guard()); } void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 0f6252499c2..11f6ac4da0d 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -808,14 +808,14 @@ def create_py_executor_instance( use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1" if use_python_scheduler: - scheduler = SimpleUnifiedScheduler( max_batch_size=max_batch_size, max_num_tokens=max_num_tokens, kv_cache_manager=kv_cache_manager, peft_cache_manager=peft_cache_manager, 
scheduler_policy=scheduler_config.capacity_scheduler_policy, - ctx_chunk_config=ctx_chunk_config) + ctx_chunk_config=ctx_chunk_config, + two_step_lookahead=mapping.has_pp()) else: capacity_scheduler = BindCapacityScheduler( scheduler_capacity, diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index be610431c31..d4533c3a70c 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -9,6 +9,7 @@ from tensorrt_llm.bindings import internal as tb_internal from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy +from tensorrt_llm.logger import logger # Assuming these imports exist in your environment from .llm_request import LlmRequest, LlmRequestState @@ -259,12 +260,35 @@ def __init__( max_batch_size: int, max_num_tokens: Optional[int] = None, ctx_chunk_config: Optional[ContextChunkingConfig] = None, + no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, + no_schedule_after_state: LlmRequestState = LlmRequestState. + GENERATION_TO_COMPLETE, ): super().__init__() self.max_batch_size = max_batch_size self.max_num_tokens = max_num_tokens self.ctx_chunk_config = ctx_chunk_config self.max_context_length = max_num_tokens + # Match C++ MicroBatchScheduler defaults (see algorithms.cpp line 68-70) + self.no_schedule_until_state = no_schedule_until_state + self.no_schedule_after_state = no_schedule_after_state + + def _has_reached_state(self, req: LlmRequest, + target_state: LlmRequestState) -> bool: + """Check if request has reached the target state.""" + # C++ equivalent: req->hasReachedState(state) + return req.state.value >= target_state.value + + def _can_be_scheduled(self, req: LlmRequest) -> bool: + """ + Check if request is within the schedulable state range. + C++ reference: microBatchScheduler.cpp line 192-195 + """ + if not self._has_reached_state(req, self.no_schedule_until_state): + return False + if self._has_reached_state(req, self.no_schedule_after_state): + return False + return True def schedule( self, active_requests: RequestList, @@ -289,23 +313,26 @@ def schedule( if req.request_id in inflight_request_ids: continue + # Skip if request cannot be scheduled yet or should no longer be scheduled + # C++ reference: microBatchScheduler.cpp line 192-195 + if not self._can_be_scheduled(req): + continue + req_num_tokens = 0 - # --- A. Encoder Request Handling (Previously Missing) --- + # --- A. Encoder Request Handling --- if req.state == LlmRequestState.ENCODER_INIT: - # C++: reqNumTokens = llmReq->getEncoderOutputLen(); req_num_tokens = req.encoder_output_len - if self.max_context_length is not None and req_num_tokens > self.max_context_length: - # C++ does TLLM_CHECK here. We skip or log. - continue + assert self.max_context_length is None or req_num_tokens <= self.max_context_length, \ + f"The number of encoder tokens ({req_num_tokens}) exceeds the limit value ({self.max_context_length})" - # Check Batch Token Budget if self.max_num_tokens is not None and (batch_num_tokens + req_num_tokens > self.max_num_tokens): break + logger.debug(f"encoder request scheduled: ID {req.request_id}") context_requests.append(req) batch_num_tokens += req_num_tokens @@ -313,24 +340,25 @@ def schedule( elif req.state == LlmRequestState.CONTEXT_INIT: if not self.ctx_chunk_config: # No Chunking: Schedule full context - # C++: getNumTokens(beam) + (hasDraft ? 
getNumDraftTokens : 0) - base_tokens = req.context_remaining_length # effectively getNumTokens(0) + # C++ uses getNumTokens(beam=0) which is tokens.size() - numPreDecodedTokens + base_tokens = req.get_num_tokens(0) draft_tokens = req.num_draft_tokens if req.has_draft_tokens else 0 req_num_tokens = base_tokens + draft_tokens - if self.max_context_length is not None and req_num_tokens > self.max_context_length: - continue + assert self.max_context_length is None or req_num_tokens <= self.max_context_length, \ + f"The number of context tokens ({req_num_tokens}) exceeds the limit value ({self.max_context_length})" if self.max_num_tokens is not None and ( batch_num_tokens + req_num_tokens > self.max_num_tokens): break + logger.debug( + f"context request scheduled: ID {req.request_id}") context_requests.append(req) batch_num_tokens += req_num_tokens else: # Chunking Enabled: Tentative schedule - # C++: setContextChunkSize(remaining); reqNumTokens = size + draft req.context_chunk_size = req.context_remaining_length draft_tokens = req.num_draft_tokens if ( @@ -338,18 +366,23 @@ def schedule( and req.has_draft_tokens) else 0 req_num_tokens = req.context_chunk_size + draft_tokens - # C++: Check maxContextLength constraints if self.max_context_length is not None: if self.max_context_length < req_num_tokens: req_num_tokens = self.max_context_length all_context_requests_fit = False + logger.debug( + f"contexts-to-be-chunked request scheduled: ID {req.request_id}" + ) contexts_to_be_chunked.append(req) num_chunked_tokens += req_num_tokens # --- C. Generation Request Handling --- else: - beam_width = req.sampling_config.beam_width + # C++ uses getBeamWidthByIter() which returns dynamic beam width + # during beam search (1->2->3->...->beamWidth) + beam_width = req.get_beam_width_by_iter( + for_next_iteration=False) req_num_tokens = beam_width + req.num_draft_tokens if self.max_num_tokens is not None and (batch_num_tokens + @@ -357,13 +390,19 @@ def schedule( > self.max_num_tokens): break - # Beam Width Consistency Check (C++ Logic) + # Beam Width Consistency Check if scheduled_beam_width == 0: scheduled_beam_width = beam_width elif scheduled_beam_width != beam_width: - # Skip requests with different beam width in this batch + logger.debug( + f"generation request skipped: ID {req.request_id} since its " + f"beam width ({beam_width}) is different from scheduled ones " + f"({scheduled_beam_width})") continue + logger.debug( + f"generation request scheduled: ID {req.request_id} " + f"with beam width {beam_width}") generation_requests.append(req) batch_num_tokens += req_num_tokens @@ -379,28 +418,80 @@ def schedule( # 3. Apply Chunking Strategy if needed if not all_context_requests_fit and contexts_to_be_chunked: - if not self.ctx_chunk_config: - pass # Error in C++: "If chunking not enabled..." - else: - remaining_capacity = ( - self.max_num_tokens - batch_num_tokens - ) if self.max_num_tokens is not None else None + assert self.ctx_chunk_config is not None, \ + "If chunking is not enabled, context scheduling should be completed." + remaining_capacity = ( + self.max_num_tokens - + batch_num_tokens) if self.max_num_tokens is not None else None - self._set_ctx_requests_chunk_size(contexts_to_be_chunked, - remaining_capacity) + self._set_ctx_requests_chunk_size(contexts_to_be_chunked, + remaining_capacity) # 4. 
Finalize Chunked Requests for req in contexts_to_be_chunked: if req.context_chunk_size > 0: context_requests.append(req) - # C++: batchNumTokens += chunk size batch_num_tokens += req.context_chunk_size - - # Note: C++ calls utils::sortRequests here. Python lists preserve order, - # usually acceptable unless specific downstream kernel requirements exist. + logger.debug(f"context request scheduled: ID {req.request_id}, " + f"chunk size {req.context_chunk_size}") + + # Sort requests for consistency with C++ + # C++ reference: utils::sortRequests in inflightBatchingUtils.cpp + self._sort_requests(context_requests, generation_requests, + not all_context_requests_fit) + + # Summary logs + logger.debug(f"batchSize (num ctx/enc requests + num gen requests): " + f"{len(context_requests) + len(generation_requests)}") + logger.debug(f"batchNumTokens / maxNumTokens: {batch_num_tokens} / " + f"{self.max_num_tokens or 0}") + logger.debug( + f"[Summary] Micro Batch scheduler schedules {len(context_requests)} " + f"context/encoder requests, {len(generation_requests)} generation requests. " + f"{len(inflight_request_ids)} requests inflight with the model already" + ) return context_requests, generation_requests + def _sort_requests(self, context_requests: RequestList, + generation_requests: RequestList, + chunks_present: bool) -> None: + """ + Sort requests for consistency with C++. + C++ reference: utils::sortRequests in inflightBatchingUtils.cpp + + 1. If chunks are present, move context requests that reached the last + context chunk to the end of the vector. + 2. Sort all requests by lora task id for performance. + """ + + def get_lora_task_id(req: LlmRequest) -> int: + # Return lora_task_id or a large value if not set + lora_id = getattr(req, 'lora_task_id', None) + if lora_id is None: + return float('inf') + return lora_id + + if chunks_present: + # Partition: non-last-chunk first, last-chunk at end + not_last_chunk = [ + r for r in context_requests if not r.is_last_context_chunk + ] + last_chunk = [ + r for r in context_requests if r.is_last_context_chunk + ] + # Sort each group by lora_task_id + not_last_chunk.sort(key=get_lora_task_id) + last_chunk.sort(key=get_lora_task_id) + # Rebuild the list in-place + context_requests.clear() + context_requests.extend(not_last_chunk) + context_requests.extend(last_chunk) + else: + context_requests.sort(key=get_lora_task_id) + + generation_requests.sort(key=get_lora_task_id) + def _set_ctx_requests_chunk_size(self, requests: RequestList, capacity: Optional[int]): # C++: Resets all chunk sizes to 0 at start @@ -414,6 +505,8 @@ def _set_ctx_requests_chunk_size(self, requests: RequestList, self._chunk_equal_progress(requests, capacity, unit_size) elif policy == ChunkingPolicy.FIRST_COME_FIRST_SERVED: self._chunk_fcfs(requests, capacity, unit_size) + else: + raise ValueError(f"Invalid chunking policy: {policy}") self._fit_draft_tokens(requests, capacity, unit_size) @@ -502,347 +595,642 @@ def _fit_draft_tokens(self, requests: RequestList, capacity: Optional[int], draft_discard = req.num_draft_tokens - remaining_space if draft_discard > 0: + logger.debug(f"Discarding {draft_discard} draft tokens") if hasattr(req, "discard_draft_tokens"): req.discard_draft_tokens(draft_discard) -class PyCapacityScheduler: +class SchedulerPolicyBase(ABC): """ - Python implementation of the C++ CapacityScheduler. - Aligned with C++ logic to support Multiple Window Sizes (VSWA). + Abstract base class for capacity scheduler policies. + Each policy implements its own scheduling logic. 
""" - def __init__( - self, - max_num_requests: int, - kv_cache_manager, - peft_cache_manager=None, - scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. - MAX_UTILIZATION, - no_schedule_until_state=LlmRequestState.CONTEXT_INIT, - no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE, - ): - self.max_num_requests = max_num_requests - self.kv_cache_manager = kv_cache_manager - self.peft_cache_manager = peft_cache_manager - self.policy = scheduler_policy - self.no_schedule_until_state = no_schedule_until_state - self.no_schedule_after_state = no_schedule_after_state + @abstractmethod + def schedule( + self, scheduler: 'PyCapacityScheduler', + active_requests: RequestList) -> Tuple[RequestList, RequestList]: + """ + Schedule requests according to the policy. - if self.kv_cache_manager is not None: - self.kv_cache_manager_cpp = kv_cache_manager.impl - self.default_window_size = self.kv_cache_manager.max_seq_len + Args: + scheduler: The capacity scheduler instance (for accessing shared state) + active_requests: List of active requests to schedule - if self.peft_cache_manager: - self.peft_cache_manager_cpp = self.peft_cache_manager.impl - self.max_peft_pages = self.peft_cache_manager_cpp.max_device_pages - else: - self.max_peft_pages = float('inf') # Effectively infinite + Returns: + Tuple of (scheduled_requests, paused_requests) + """ + raise NotImplementedError - def schedule_request( - self, active_requests: RequestList - ) -> Tuple[RequestList, RequestList, RequestList]: - if self.kv_cache_manager is None: - return self._schedule_max_requests(active_requests) +class MaxRequestsPolicy(SchedulerPolicyBase): + """ + MaxRequestsScheduler: Simple request count limiting without KV cache checks. + C++ reference: capacityScheduler.cpp:154-176 + """ - if self.policy == CapacitySchedulerPolicy.MAX_UTILIZATION: - return self._schedule_max_utilization(active_requests) - elif self.policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: - return self._schedule_guaranteed_no_evict(active_requests) - else: - raise NotImplementedError( - f"Policy {self.policy} not implemented in PyCapacityScheduler") + def schedule( + self, scheduler: 'PyCapacityScheduler', + active_requests: RequestList) -> Tuple[RequestList, RequestList]: + scheduled_requests: RequestList = [] - def _get_initial_available_blocks_map(self) -> Dict[int, int]: - """ - Mimics C++: mKvCacheManager.getBlockManager().getNumFreeBlocksPerWindowSize() - Returns a dict {window_size: free_blocks}. - """ - stats = self.kv_cache_manager_cpp.get_kv_cache_stats() + for req in active_requests: + if not scheduler._can_be_scheduled(req): + continue - # Nanobind binds std::map to python dict - # Property name from binding: .def_rw("num_free_blocks_per_window_size", ...) 
- free_map = stats.num_free_blocks_per_window_size + if len(scheduled_requests) >= scheduler.max_num_requests: + break - if not free_map: - # Fallback for simple cases or if map is empty (though unlikely in C++) - # Calculate scalar free blocks - free_scalar = stats.max_num_blocks - stats.used_num_blocks - return {self.default_window_size: free_scalar} + if (req.is_encoder_init_state or req.is_context_init_state + or req.is_generation_in_progress_state): + scheduled_requests.append(req) - # Ensure we return a copy so we can modify it during scheduling - return dict(free_map) + return scheduled_requests, [] - def _req_check_and_update_map(self, req: LlmRequest, - available_map: Dict[int, int], - is_guaranteed_no_evict: bool) -> bool: - """ - Checks if a request fits in ALL window sizes tracked in available_map. - If it fits, decrements the map and returns True. - If it doesn't fit, leaves map untouched and returns False. - """ - # 1. Calculate needed blocks for all window sizes - needed_per_window = {} - for window_size in available_map.keys(): - if is_guaranteed_no_evict: - # C++: getRemainingBlocksToCompletion(req, windowSize) - needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( - req, window_size) - else: - # C++: getNeededBlocksOneStep(req, twoStepsLookAhead, windowSize) - needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, window_size) - needed_per_window[window_size] = needed - - # 2. Check if fits (All or Nothing) - for window_size, available in available_map.items(): - if needed_per_window[window_size] > available: - return False - # 3. Commit update - for window_size in available_map.keys(): - available_map[window_size] -= needed_per_window[window_size] +class GuaranteedNoEvictPolicy(SchedulerPolicyBase): + """ + GuaranteedNoEvictScheduler: Reserve blocks for requests to complete without eviction. + C++ reference: capacityScheduler.cpp:194-331 + """ - return True + def __init__(self, static_batch: bool = False): + self.static_batch = static_batch - def _req_force_update_map(self, req: LlmRequest, available_map: Dict[int, - int], - is_guaranteed_no_evict: bool): - """ - Unconditionally decrements the available blocks (used for Running requests in NoEvict). - Allowed to go negative. - """ - for window_size in available_map.keys(): - if is_guaranteed_no_evict: - needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( - req, window_size) - else: - needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, window_size) + def schedule( + self, scheduler: 'PyCapacityScheduler', + active_requests: RequestList) -> Tuple[RequestList, RequestList]: + scheduled_requests: RequestList = [] + max_peft_pages = scheduler._get_max_peft_pages() - available_map[window_size] -= needed + skipping_is_relevant = scheduler._is_skipping_relevant() - def _req_revert_map(self, req: LlmRequest, available_map: Dict[int, int], - is_guaranteed_no_evict: bool): - """ - Reverts a decrement (used for Backtracking in MaxUtilization). 
- """ - for window_size in available_map.keys(): - if is_guaranteed_no_evict: - needed = self.kv_cache_manager_cpp.get_remaining_blocks_to_completion( - req, window_size) - else: - needed = self.kv_cache_manager_cpp.get_needed_blocks_one_step( - req, False, window_size) + newly_contributed_context_blocks: Set = set() + newly_contributed_cross_context_blocks: Set = set() + if not self.static_batch and skipping_is_relevant: + newly_contributed_context_blocks, newly_contributed_cross_context_blocks = \ + scheduler._prefill_contributed_blocks(active_requests) - available_map[window_size] += needed + reserved_blocks = NoEvictScheduledBlocksManager( + scheduler.kv_cache_manager) + reserved_cross_blocks: Optional[NoEvictScheduledBlocksManager] = None + if scheduler.cross_kv_cache_manager is not None: + reserved_cross_blocks = NoEvictScheduledBlocksManager( + scheduler.cross_kv_cache_manager) - def _schedule_max_requests(self, active_requests: RequestList): - scheduled_requests: RequestList = [] - paused_requests: RequestList = [] + claimed_peft_pages = 0 + uniq_task_ids: Set[int] = set() - for req in active_requests: - is_disagg_init = ( - req.state == LlmRequestState.DISAGG_GENERATION_INIT) + pending_requests: RequestList = [] + pending_dis_gen_init_requests: RequestList = [] - if not is_disagg_init and ( - req.state.value < self.no_schedule_until_state.value - or req.state.value >= self.no_schedule_after_state.value): + # First pass: process in-progress generation and classify requests + for req in active_requests: + if not scheduler._can_be_scheduled_with_disagg_exception(req): continue - if len(scheduled_requests) >= self.max_num_requests: + if len(scheduled_requests) >= scheduler.max_num_requests: break - if (req.state == LlmRequestState.ENCODER_INIT - or req.state == LlmRequestState.CONTEXT_INIT - or req.state == LlmRequestState.GENERATION_IN_PROGRESS - or is_disagg_init): + if req.is_generation_in_progress_state: scheduled_requests.append(req) + reserved_blocks.decrement_reserved_blocks(req) + if reserved_cross_blocks is not None: + reserved_cross_blocks.decrement_reserved_blocks(req) - return self._classify_output(active_requests, scheduled_requests, - paused_requests) + lora_task_id, is_new_task, peft_pages = scheduler._get_peft_task_info( + req, uniq_task_ids) + if is_new_task: + claimed_peft_pages += peft_pages + uniq_task_ids.add(lora_task_id) - def _schedule_max_utilization(self, active_requests: RequestList): - scheduled_requests: RequestList = [] - paused_requests: RequestList = [] + elif req.is_disagg_generation_init_state: + pending_dis_gen_init_requests.append(req) + else: + pending_requests.append(req) + + # Second pass: process pending requests + if not self.static_batch or len(scheduled_requests) == 0: + available_peft_pages = max_peft_pages - claimed_peft_pages + + for requests in [pending_dis_gen_init_requests, pending_requests]: + for req in requests: + if (not self.static_batch and skipping_is_relevant + and not req.is_disagg_generation_init_state + and scheduler._beneficial_to_skip( + req, newly_contributed_context_blocks, + newly_contributed_cross_context_blocks)): + continue - self.kv_cache_manager_cpp.start_scheduling() + if len(scheduled_requests) >= scheduler.max_num_requests: + break - # [FIX] Use Map tracking for multiple window sizes - current_free_blocks_map = self._get_initial_available_blocks_map() + if req.is_context_init_state or req.is_disagg_generation_init_state: + enough_blocks = reserved_blocks.enough_available_blocks( + req) + enough_cross_blocks = 
True + if reserved_cross_blocks is not None: + enough_cross_blocks = reserved_cross_blocks.enough_available_blocks( + req) + + lora_task_id, is_new_task, needed_peft_pages = scheduler._get_peft_task_info( + req, uniq_task_ids) + + if (enough_blocks and enough_cross_blocks + and needed_peft_pages <= available_peft_pages): + scheduled_requests.append(req) + reserved_blocks.decrement_reserved_blocks(req) + if reserved_cross_blocks is not None: + reserved_cross_blocks.decrement_reserved_blocks( + req) + available_peft_pages -= needed_peft_pages + if is_new_task: + uniq_task_ids.add(lora_task_id) + elif not enough_blocks or not enough_cross_blocks: + break + + return scheduled_requests, [] + + +class MaxUtilizationPolicy(SchedulerPolicyBase): + """ + MaxUtilizationScheduler: Maximize utilization, may pause started requests. + C++ reference: capacityScheduler.cpp:341-425 + """ - cached_active_list = list(active_requests) - idx = 0 + def schedule( + self, scheduler: 'PyCapacityScheduler', + active_requests: RequestList) -> Tuple[RequestList, RequestList]: + scheduler.kv_cache_manager.impl.start_scheduling() - while idx < len(cached_active_list): - req = cached_active_list[idx] + skipping_is_relevant = scheduler._is_skipping_relevant() - is_disagg_init = ( - req.state == LlmRequestState.DISAGG_GENERATION_INIT) + scheduled_blocks_manager = MaxUtilizationScheduledBlocksManager( + scheduler.kv_cache_manager, scheduler.two_step_lookahead) - if not is_disagg_init and ( - req.state.value < self.no_schedule_until_state.value - or req.state.value >= self.no_schedule_after_state.value): - idx += 1 - continue + num_scheduled_peft_pages = 0 + seen_task_ids: Set[int] = set() - if len(scheduled_requests) >= self.max_num_requests: - break + newly_contributed_context_blocks, _ = scheduler._prefill_contributed_blocks( + active_requests) - # 3. Try Allocation - # C++ Logic: Checks if it fits in *all* window sizes - if self._req_check_and_update_map(req, - current_free_blocks_map, - is_guaranteed_no_evict=False): - scheduled_requests.append(req) - idx += 1 + def is_started_request(req: LlmRequest) -> bool: + if not scheduler._can_be_scheduled(req): + return False + return ((req.is_context_init_state + and not req.is_first_context_chunk) + or req.is_generation_in_progress_state) + + scheduled_requests: RequestList = [] + paused_requests: RequestList = [] + + requests_list = list(active_requests) + req_it_end = len(requests_list) + req_it = 0 + + while req_it < req_it_end: + req = requests_list[req_it] + logger.debug( + f"MaxUtilizationScheduler: scheduling request ID {req.request_id}" + ) + + if not scheduler._can_be_scheduled_with_disagg_exception(req): + logger.debug( + f"MaxUtilizationScheduler: request ID {req.request_id} " + "cannot / should not be scheduled") + req_it += 1 + continue + + if (skipping_is_relevant and scheduler._beneficial_to_skip( + req, newly_contributed_context_blocks, set())): + req_it += 1 continue - # 4. 
Backtracking (Evict Generation requests only) - victim_idx = -1 - for i in range(len(scheduled_requests) - 1, -1, -1): - r = scheduled_requests[i] - if r.state == LlmRequestState.GENERATION_IN_PROGRESS: - victim_idx = i + was_scheduled = self._try_scheduling_request( + scheduler, req, scheduled_requests, scheduled_blocks_manager, + num_scheduled_peft_pages, seen_task_ids) + + if was_scheduled: + logger.debug( + f"MaxUtilizationScheduler: request ID {req.request_id} -> start" + ) + req_it += 1 + else: + last_started_idx = None + for i in range(req_it_end - 1, req_it - 1, -1): + if is_started_request(requests_list[i]): + last_started_idx = i + break + + if last_started_idx is not None: + paused_req = requests_list[last_started_idx] + scheduler.kv_cache_manager.impl.scheduling_remove_sequence( + paused_req.py_request_id) + paused_requests.append(paused_req) + logger.debug( + f"MaxUtilizationScheduler: request ID {paused_req.request_id} -> pause" + ) + req_it_end = last_started_idx + else: break - if victim_idx != -1: - # Found a victim. Evict it. - victim_req = scheduled_requests.pop(victim_idx) - paused_requests.append(victim_req) + return scheduled_requests, paused_requests + + def _try_scheduling_request( + self, scheduler: 'PyCapacityScheduler', req: LlmRequest, + scheduled_requests: RequestList, + scheduled_blocks_manager: 'MaxUtilizationScheduledBlocksManager', + num_scheduled_peft_pages: int, seen_task_ids: Set[int]) -> bool: + if len(scheduled_requests) >= scheduler.max_num_requests: + return False + + lora_task_id, is_new_task, num_required_peft_pages = scheduler._get_peft_task_info( + req, seen_task_ids) + logger.debug(f"MaxUtilizationScheduler: request ID {req.request_id} " + f"required peft pages: {num_required_peft_pages}") + + blocks_if_scheduled = scheduled_blocks_manager.prepare_blocks_if_schedulable( + req) + fits_kv_cache = blocks_if_scheduled is not None + + fits_peft = True + if scheduler.peft_cache_manager is not None: + max_peft_pages = scheduler._get_max_peft_pages() + fits_peft = (num_required_peft_pages + num_scheduled_peft_pages + <= max_peft_pages) + + if fits_kv_cache and fits_peft: + scheduled_blocks_manager.update_scheduled_blocks( + blocks_if_scheduled) + logger.debug( + f"MaxUtilizationScheduler: scheduled peft pages: {num_required_peft_pages}" + ) + scheduled_requests.append(req) + if is_new_task: + seen_task_ids.add(lora_task_id) + return True - # Revert victim's usage in the map - self._req_revert_map(victim_req, - current_free_blocks_map, - is_guaranteed_no_evict=False) + return False - # Retry current req (do NOT increment idx) - continue - else: - # No victim found, and current request doesn't fit. Stop. - break - return self._classify_output(active_requests, scheduled_requests, - paused_requests) +class NoEvictScheduledBlocksManager: + """ + Python equivalent of C++ kv_cache_manager::NoEvictScheduledBlocksManager. + Tracks available blocks per window size for GUARANTEED_NO_EVICT scheduling. - def _schedule_guaranteed_no_evict(self, active_requests: RequestList): - scheduled_requests: RequestList = [] - pending_disagg_requests: RequestList = [] - pending_context_requests: RequestList = [] + Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:29-62 + """ - # KV Cache Resource Tracking - available_blocks_map = self._get_initial_available_blocks_map() + def __init__(self, kv_cache_manager): + """ + Initialize with free blocks from KVCacheManager. 
+ C++ equivalent: mAvailableBlocks = mKvCacheManager.getBlockManager().getNumFreeBlocksPerWindowSize() + """ + self.kv_cache_manager = kv_cache_manager + stats = kv_cache_manager.impl.get_kv_cache_stats() + self.available_blocks: Dict[int, int] = dict( + stats.num_free_blocks_per_window_size) - # PEFT Resource Tracking - claimed_peft_pages = 0 - uniq_task_ids: Set[int] = set() + def decrement_reserved_blocks(self, req: LlmRequest) -> None: + """ + Decrement available blocks by the blocks needed to complete this request. + C++ reference: scheduledBlocksManager.h:40-46 + """ + for window_size in self.available_blocks: + needed = self.kv_cache_manager.impl.get_remaining_blocks_to_completion( + req, window_size) + self.available_blocks[window_size] -= needed - # --- Pass 1: Running Requests --- - for request in active_requests: - req_state = request.state - is_disagg_init = ( - req_state == LlmRequestState.DISAGG_GENERATION_INIT) + def enough_available_blocks(self, req: LlmRequest) -> bool: + """ + Check if there are enough available blocks for this request across all window sizes. + C++ reference: scheduledBlocksManager.h:48-57 + """ + return all( + self.kv_cache_manager.impl.get_remaining_blocks_to_completion( + req, ws) <= avail + for ws, avail in self.available_blocks.items()) - if not is_disagg_init and ( - req_state.value < self.no_schedule_until_state.value - or req_state.value >= self.no_schedule_after_state.value): - continue - if len(scheduled_requests) >= self.max_num_requests: - if is_disagg_init: - pending_disagg_requests.append(request) - else: - pending_context_requests.append(request) - continue +class MaxUtilizationScheduledBlocksManager: + """ + Python equivalent of C++ kv_cache_manager::MaxUtilizationScheduledBlocksManager. + Tracks scheduled blocks per window size for MAX_UTILIZATION scheduling. - if (req_state == LlmRequestState.GENERATION_IN_PROGRESS - or req_state == LlmRequestState.GENERATION_TO_COMPLETE): - - # 1. Update KV Cache Map (Unconditional) - self._req_force_update_map(request, - available_blocks_map, - is_guaranteed_no_evict=True) - - # 2. Update PEFT Usage - # C++: if (isNewTask) claimedPeftPages += determineNumPages(req); - if self.peft_cache_manager and request.lora_task_id is not None: - task_id = request.lora_task_id - if task_id not in uniq_task_ids: - # Binding check: determine_num_pages - pages = self.peft_cache_manager_cpp.determine_num_pages( - request) - claimed_peft_pages += pages - uniq_task_ids.add(task_id) - - scheduled_requests.append(request) - else: - if is_disagg_init: - pending_disagg_requests.append(request) - else: - pending_context_requests.append(request) + Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:64-117 + """ - # --- Pass 2: New / Context Requests --- - available_peft_pages = self.max_peft_pages - claimed_peft_pages - all_pending = pending_disagg_requests + pending_context_requests + def __init__(self, kv_cache_manager, two_steps_look_ahead: bool): + """ + Initialize scheduled blocks count per window size. + C++ equivalent: iterate windowSizes and set mNumScheduledBlocks[windowSize] = 0 + """ + self.kv_cache_manager = kv_cache_manager + self.two_steps_look_ahead = two_steps_look_ahead + window_sizes = set(kv_cache_manager.max_attention_window_vec) + self.num_scheduled_blocks: Dict[int, int] = { + ws: 0 + for ws in window_sizes + } + + def prepare_blocks_if_schedulable( + self, req: LlmRequest) -> Optional[Dict[int, int]]: + """ + Check if request can be scheduled and return new block counts if so. 
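
As a quick illustration of the prepare/commit split used here together with update_scheduled_blocks below: candidate per-window totals are computed first, and the shared counters are only updated once every check (including the PEFT check in the policy) has passed. The helper names has_free and needed_one_step are stand-ins for the real binding calls, not part of this patch.

def try_schedule(scheduled_per_window, needed_one_step, has_free, peft_ok):
    candidate = {}
    for ws, already in scheduled_per_window.items():
        total = already + needed_one_step[ws]
        if not has_free(total, ws):
            return False                    # one window does not fit: no counters were touched
        candidate[ws] = total
    if not peft_ok:
        return False                        # KV cache fits but PEFT does not: still no side effects
    scheduled_per_window.update(candidate)  # commit only after everything passed
    return True


free = {1024: 8, 4096: 8}
state = {1024: 0, 4096: 0}
print(try_schedule(state, {1024: 3, 4096: 5},
                   lambda total, ws: total <= free[ws], peft_ok=True), state)
# True {1024: 3, 4096: 5}
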
+ Returns None if request cannot fit. + C++ reference: scheduledBlocksManager.h:80-100 + """ + blocks_if_scheduled = {} + for window_size, num_scheduled in self.num_scheduled_blocks.items(): + required = self.kv_cache_manager.impl.get_needed_blocks_one_step( + req, self.two_steps_look_ahead, window_size) + logger.debug( + f"MaxUtilizationScheduler: request ID {req.request_id} " + f"required blocks {required} for {window_size} window size") + scheduled_total = num_scheduled + required + has_free = self.kv_cache_manager.impl.scheduling_has_free_blocks( + scheduled_total, window_size) + if not has_free: + return None + blocks_if_scheduled[window_size] = scheduled_total + return blocks_if_scheduled + + def update_scheduled_blocks(self, blocks: Dict[int, int]) -> None: + """ + Update the scheduled blocks after successfully scheduling a request. + C++ reference: scheduledBlocksManager.h:102-110 + """ + assert len(blocks) == len(self.num_scheduled_blocks), \ + f"Block count mismatch: {len(blocks)} vs {len(self.num_scheduled_blocks)}" + for window_size, blocks_if_scheduled in blocks.items(): + logger.debug( + f"MaxUtilizationScheduler: scheduled blocks {blocks_if_scheduled} " + f"for window size {window_size}") + self.num_scheduled_blocks[window_size] = blocks_if_scheduled - for request in all_pending: - if len(scheduled_requests) >= self.max_num_requests: - break - # 1. Check PEFT Capacity - needed_peft_pages = 0 - is_new_task = False - task_id = None +class PyCapacityScheduler: + """ + Python implementation of the C++ CapacityScheduler. + Aligned 1:1 with C++ logic in cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp. + Supports Multiple Window Sizes (VSWA), block reuse optimization, and all policies. - if self.peft_cache_manager and request.lora_task_id is not None: - task_id = request.lora_task_id - is_new_task = (task_id not in uniq_task_ids) - if is_new_task: - needed_peft_pages = self.peft_cache_manager_cpp.determine_num_pages( - request) - if needed_peft_pages > available_peft_pages: - # Not enough PEFT memory - break # Head-of-line blocking - - # 2. Check KV Cache Capacity - if not self._req_check_and_update_map( - request, available_blocks_map, is_guaranteed_no_evict=True): - # Not enough KV blocks - break + Policies: + - MaxRequestsScheduler: No KV cache manager, simple request count limit + - GuaranteedNoEvictScheduler: Reserve blocks for completion, no eviction + - StaticBatchScheduler: Only schedule when no requests are active + - MaxUtilizationScheduler: Maximize utilization, may pause requests - # 3. Commit Schedule - scheduled_requests.append(request) + Reference: cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h + """ - # Commit PEFT usage - if is_new_task: - available_peft_pages -= needed_peft_pages - uniq_task_ids.add(task_id) + def __init__( + self, + max_num_requests: int, + kv_cache_manager=None, + peft_cache_manager=None, + scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. + GUARANTEED_NO_EVICT, + cross_kv_cache_manager=None, + two_step_lookahead: bool = False, + no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, + no_schedule_after_state: LlmRequestState = LlmRequestState. + GENERATION_COMPLETE, + ): + """ + Initialize the capacity scheduler. 
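
A minimal construction sketch, assuming a build of tensorrt_llm that contains this patch; when no KV cache manager is passed, _create_policy (defined below) falls back to the request-count-only policy regardless of which policy enum was requested.

from tensorrt_llm._torch.pyexecutor.scheduler import (CapacitySchedulerPolicy,
                                                      PyCapacityScheduler)

scheduler = PyCapacityScheduler(
    max_num_requests=8,
    kv_cache_manager=None,   # no KV cache manager available
    scheduler_policy=CapacitySchedulerPolicy.MAX_UTILIZATION)

# Without a KV cache manager the requested policy is ignored in favor of the
# simple request-count limit.
print(type(scheduler._policy).__name__)   # MaxRequestsPolicy
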
+ + Args: + max_num_requests: Maximum number of requests to schedule + kv_cache_manager: KV cache manager (None for MaxRequestsScheduler) + peft_cache_manager: PEFT/LoRA cache manager (optional) + scheduler_policy: Scheduling policy + cross_kv_cache_manager: Cross-attention KV cache manager for encoder-decoder + two_step_lookahead: Enable two-step lookahead for MAX_UTILIZATION + no_schedule_until_state: Don't schedule until this state is reached + no_schedule_after_state: Don't schedule after this state is reached + """ + self.max_num_requests = max_num_requests + self.kv_cache_manager = kv_cache_manager + self.peft_cache_manager = peft_cache_manager + self.cross_kv_cache_manager = cross_kv_cache_manager + self.scheduler_policy = scheduler_policy + self.two_step_lookahead = two_step_lookahead + self.no_schedule_until_state = no_schedule_until_state + self.no_schedule_after_state = no_schedule_after_state - return self._classify_output(active_requests, scheduled_requests, []) + # Initialize the appropriate policy + self._policy = self._create_policy() - def _classify_output(self, active_requests, scheduled_requests, - explicit_paused_requests): - fitting_requests = [] - fitting_disagg_gen_init = [] - paused_requests = list(explicit_paused_requests) + def _create_policy(self) -> SchedulerPolicyBase: + """Create the appropriate policy based on configuration.""" + if self.kv_cache_manager is None: + return MaxRequestsPolicy() + elif self.scheduler_policy == CapacitySchedulerPolicy.MAX_UTILIZATION: + return MaxUtilizationPolicy() + elif self.scheduler_policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: + return GuaranteedNoEvictPolicy(static_batch=False) + elif self.scheduler_policy == CapacitySchedulerPolicy.STATIC_BATCH: + return GuaranteedNoEvictPolicy(static_batch=True) + else: + raise ValueError( + f"Unsupported scheduler policy: {self.scheduler_policy}") - scheduled_ids = set(r.request_id for r in scheduled_requests) - paused_ids = set(r.request_id for r in paused_requests) + def _has_reached_state(self, req: LlmRequest, + target_state: LlmRequestState) -> bool: + """Check if request has reached the target state.""" + # C++ equivalent: req->hasReachedState(state) + # States are ordered: ENCODER_INIT(1) < CONTEXT_INIT(2) < GENERATION_IN_PROGRESS(3) < ... + return req.state.value >= target_state.value + + def _can_be_scheduled(self, req: LlmRequest) -> bool: + """ + Check if request is within the schedulable state range. + Returns True if request has reached no_schedule_until_state + but has not yet reached no_schedule_after_state. + """ + if not self._has_reached_state(req, self.no_schedule_until_state): + return False + if self._has_reached_state(req, self.no_schedule_after_state): + return False + return True + + def _is_skipping_relevant(self) -> bool: + """ + Check if block reuse skip optimization is relevant. + Disabled for VSWA (Variable Sliding Window Attention). + C++ reference: capacityScheduler.cpp:207-208, 348 + """ + if self.kv_cache_manager is None: + return False + if self.kv_cache_manager.is_vswa: + return False + if (self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.is_vswa): + return False + return True + + def _prefill_contributed_blocks( + self, active_requests: RequestList) -> Tuple[Set, Set]: + """ + Collect blocks contributed by chunked context requests already executing. + These blocks can be reused by later requests. 
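
A toy version of the collection step, using tuples of token ids as stand-in block keys (the real keys come from find_new_context_block, which is a binding call and not reproduced here):

def contributed_block_keys(inflight_chunked_prefixes, block_size=4):
    """Record one reusable block key per in-flight chunked-context prefix."""
    keys = set()
    for tokens in inflight_chunked_prefixes:
        if len(tokens) >= block_size:        # this toy only treats a full first block as reusable
            keys.add(tuple(tokens[:block_size]))
    return keys


inflight = [[1, 2, 3, 4, 9, 9], [1, 2, 3, 4, 7]]    # two executing requests sharing a prefix
print(contributed_block_keys(inflight))             # {(1, 2, 3, 4)}: one shared, reusable key
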
+ + C++ reference: capacityScheduler.cpp:34-68 (prefillWithChunkedContextsAlreadyExecuting) + """ + newly_contributed_context_blocks: Set = set() + newly_contributed_cross_context_blocks: Set = set() + + if self.kv_cache_manager is None: + return newly_contributed_context_blocks, newly_contributed_cross_context_blocks + + enable_block_reuse = self.kv_cache_manager.enable_block_reuse + cross_enable_reuse = (self.cross_kv_cache_manager is not None and + self.cross_kv_cache_manager.enable_block_reuse) for req in active_requests: - if (req.request_id not in scheduled_ids - and req.request_id not in paused_ids - and req.state == LlmRequestState.GENERATION_IN_PROGRESS): - paused_requests.append(req) + # Check: isContextInitState() && !isFirstContextChunk() + if req.is_context_init_state and not req.is_first_context_chunk: + # Chunked context request already executing + if enable_block_reuse: + unique_tokens = req.get_unique_tokens(0) + block_key = self.kv_cache_manager.impl.find_new_context_block( + unique_tokens, req) + if block_key is not None: + newly_contributed_context_blocks.add(block_key) + + if cross_enable_reuse: + encoder_unique_tokens = req.get_encoder_unique_tokens() + if encoder_unique_tokens is not None: + block_key = self.cross_kv_cache_manager.impl.find_new_context_block( + encoder_unique_tokens, req) + if block_key is not None: + newly_contributed_cross_context_blocks.add( + block_key) + + return newly_contributed_context_blocks, newly_contributed_cross_context_blocks + + def _one_manager_beneficial_to_skip(self, kv_cache_manager, unique_tokens, + req: LlmRequest, + newly_contributed_blocks: Set) -> bool: + """ + Check if skipping is beneficial for one KV cache manager. + C++ reference: capacityScheduler.cpp:70-92 (oneManagerBeneficialToSkip) + """ + new_context_block = kv_cache_manager.impl.find_new_context_block( + unique_tokens, req) + if new_context_block is not None: + if new_context_block in newly_contributed_blocks: + return True + newly_contributed_blocks.add(new_context_block) + return False + + def _beneficial_to_skip( + self, req: LlmRequest, newly_contributed_context_blocks: Set, + newly_contributed_cross_context_blocks: Set) -> bool: + """ + Check if it's beneficial to skip this request. + A request should be skipped if it can reuse blocks contributed by + already scheduled context requests. 
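
The decision itself is a set-based de-duplication, sketched standalone below: the first request that would create a given block registers the key and is not skipped, while a later request mapping to the same key is skipped for this iteration because it will be able to reuse that block once it exists.

def beneficial_to_skip(block_key, contributed):
    if block_key is None:          # nothing reusable would be created
        return False
    if block_key in contributed:   # an earlier request in this pass already covers it
        return True
    contributed.add(block_key)     # register our contribution for requests that follow
    return False


contributed = set()
print(beneficial_to_skip(("shared", "prefix"), contributed))   # False: schedule it
print(beneficial_to_skip(("shared", "prefix"), contributed))   # True: skip and reuse later
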
+ + C++ reference: capacityScheduler.cpp:97-123 (beneficialToSkip) + """ + if not (req.is_context_init_state and req.is_first_context_chunk): + return False + + if (self.kv_cache_manager is not None + and self.kv_cache_manager.enable_block_reuse): + unique_tokens = req.get_unique_tokens(0) + if self._one_manager_beneficial_to_skip( + self.kv_cache_manager, unique_tokens, req, + newly_contributed_context_blocks): + return True + + if (self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.enable_block_reuse): + encoder_unique_tokens = req.get_encoder_unique_tokens() + if encoder_unique_tokens is not None: + if self._one_manager_beneficial_to_skip( + self.cross_kv_cache_manager, encoder_unique_tokens, req, + newly_contributed_cross_context_blocks): + return True + + return False + + def _get_max_peft_pages(self) -> int: + """Get maximum PEFT cache pages.""" + if self.peft_cache_manager is None: + return 2**31 - 1 # INT_MAX equivalent + return self.peft_cache_manager.get_max_device_pages() + + def _get_peft_pages_for_request(self, req: LlmRequest) -> int: + """Get PEFT pages needed for a request.""" + if self.peft_cache_manager is None: + return 0 + return self.peft_cache_manager.determine_num_pages(req) + + def _get_peft_task_info( + self, req: LlmRequest, + seen_task_ids: Set[int]) -> Tuple[Optional[int], bool, int]: + """ + Get PEFT task information for a request. + Returns (lora_task_id, is_new_task, required_pages). + """ + lora_task_id = getattr(req, 'lora_task_id', None) + is_new_task = lora_task_id is not None and lora_task_id not in seen_task_ids + required_pages = self._get_peft_pages_for_request( + req) if is_new_task else 0 + return lora_task_id, is_new_task, required_pages + + def _can_be_scheduled_with_disagg_exception(self, req: LlmRequest) -> bool: + """ + Check if request can be scheduled, with exception for disagg generation init state. + Disagg generation init requests bypass the normal state gating. + """ + if req.is_disagg_generation_init_state: + return True + return self._can_be_scheduled(req) + + def schedule_request( + self, active_requests: RequestList + ) -> Tuple[RequestList, RequestList, RequestList]: + """ + Schedule requests based on the configured policy. + + Args: + active_requests: List of active requests to consider + + Returns: + Tuple of (fitting_requests, fitting_disagg_gen_init_requests, paused_requests) + + C++ reference: capacityScheduler.cpp:488-539 (CapacityScheduler::operator()) + """ + scheduled, paused = self._policy.schedule(self, active_requests) + + fitting_requests, fitting_disagg_gen_init_requests = self._classify_output( + scheduled) + logger.debug( + f"[Summary] Capacity scheduler allows {len(fitting_requests)} requests, " + f"pauses {len(paused)} requests") + + return fitting_requests, fitting_disagg_gen_init_requests, paused + + def _classify_output( + self, + scheduled_requests: RequestList) -> Tuple[RequestList, RequestList]: + """ + Separate scheduled requests into normal requests and disagg gen init requests. 
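
The classification is a plain single-predicate partition; the toy below shows the intended shape of the first two outputs of schedule_request, with a small dataclass standing in for LlmRequest.

from dataclasses import dataclass


@dataclass
class FakeReq:
    request_id: int
    is_disagg_generation_init_state: bool = False


def classify(scheduled):
    fitting, disagg_init = [], []
    for req in scheduled:
        (disagg_init if req.is_disagg_generation_init_state else fitting).append(req)
    return fitting, disagg_init


fitting, disagg_init = classify([FakeReq(1), FakeReq(2, True), FakeReq(3)])
print([r.request_id for r in fitting], [r.request_id for r in disagg_init])   # [1, 3] [2]
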
+ C++ reference: capacityScheduler.cpp:522-534 + """ + fitting_requests: RequestList = [] + fitting_disagg_gen_init_requests: RequestList = [] for req in scheduled_requests: - if req.state == LlmRequestState.DISAGG_GENERATION_INIT: - fitting_disagg_gen_init.append(req) + if req.is_disagg_generation_init_state: + fitting_disagg_gen_init_requests.append(req) else: fitting_requests.append(req) - - return fitting_requests, fitting_disagg_gen_init, paused_requests + return fitting_requests, fitting_disagg_gen_init_requests class SimpleUnifiedScheduler(RequestScheduler): @@ -855,13 +1243,18 @@ def __init__( peft_cache_manager, scheduler_policy: CapacitySchedulerPolicy, ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None, + cross_kv_cache_manager=None, + two_step_lookahead: bool = False, ): # 1. Initialize Python Capacity Scheduler + # Now fully aligned with C++ CapacityScheduler self.capacity_scheduler = PyCapacityScheduler( max_num_requests=max_batch_size, kv_cache_manager=kv_cache_manager, peft_cache_manager=peft_cache_manager, - scheduler_policy=scheduler_policy) + scheduler_policy=scheduler_policy, + cross_kv_cache_manager=cross_kv_cache_manager, + two_step_lookahead=two_step_lookahead) # 2. Initialize Python MicroBatch Scheduler py_chunk_config = None From 80b725392cca47f687467178b96bc0ccde24013e Mon Sep 17 00:00:00 2001 From: Lanyu Liao Date: Wed, 24 Dec 2025 19:45:39 -0800 Subject: [PATCH 23/25] fix part of CI failues by exposing more c++ api Signed-off-by: Lanyu Liao --- .../nanobind/batch_manager/bindings.cpp | 1 + .../nanobind/batch_manager/kvCacheManager.cpp | 4 +- .../pybind/batch_manager/bindings.cpp | 1 + .../pybind/batch_manager/kvCacheManager.cpp | 4 +- tensorrt_llm/_torch/pyexecutor/_util.py | 6 ++- tensorrt_llm/_torch/pyexecutor/scheduler.py | 38 +++++++++---------- 6 files changed, 31 insertions(+), 23 deletions(-) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index b75f8e8bd69..2e19dd792d0 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -174,6 +174,7 @@ void initBindings(nb::module_& m) .def_prop_ro("is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete) .def_prop_ro( "is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress) + .def_prop_ro("is_encoder_init_state", &GenLlmReq::isEncoderInitState) .def_prop_ro("is_context_init_state", &GenLlmReq::isContextInitState) .def_prop_ro("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState) .def_prop_ro("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 2fb6ad95a9e..5392cc9ac5a 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -531,7 +531,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) "scheduling_has_free_blocks", [](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize) { return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); }, - nb::arg("num_required"), nb::arg("window_size"), nb::call_guard()); + nb::arg("num_required"), nb::arg("window_size"), nb::call_guard()) + .def_prop_ro( + "is_variable_window", [](tbk::KVCacheManager& self) 
{ return self.getBlockManager().isVariableWindow(); }); } void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 510571613e7..40c8a5c89a4 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -180,6 +180,7 @@ void initBindings(pybind11::module_& m) "is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete) .def_property_readonly( "is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress) + .def_property_readonly("is_encoder_init_state", &GenLlmReq::isEncoderInitState) .def_property_readonly("is_context_init_state", &GenLlmReq::isContextInitState) .def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState) .def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 2ef12236795..5280ce497c5 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -526,7 +526,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) "scheduling_has_free_blocks", [](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize) { return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); }, - py::arg("num_required"), py::arg("window_size"), py::call_guard()); + py::arg("num_required"), py::arg("window_size"), py::call_guard()) + .def_property_readonly( + "is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); }); } void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 11f6ac4da0d..e4a04eebd3e 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -811,8 +811,10 @@ def create_py_executor_instance( scheduler = SimpleUnifiedScheduler( max_batch_size=max_batch_size, max_num_tokens=max_num_tokens, - kv_cache_manager=kv_cache_manager, - peft_cache_manager=peft_cache_manager, + kv_cache_manager=kv_cache_manager.impl + if kv_cache_manager is not None else None, + peft_cache_manager=peft_cache_manager.impl + if peft_cache_manager is not None else None, scheduler_policy=scheduler_config.capacity_scheduler_policy, ctx_chunk_config=ctx_chunk_config, two_step_lookahead=mapping.has_pp()) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index d4533c3a70c..4efa06d2ed5 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -465,12 +465,13 @@ def _sort_requests(self, context_requests: RequestList, 2. Sort all requests by lora task id for performance. 
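
The hunk below swaps the sort key for a (has_value, value) tuple so that Python reproduces the C++ std::optional ordering; a quick self-contained check of that ordering:

def lora_sort_key(lora_task_id):
    # Requests with no LoRA task (None) sort first, then by task id.
    return (0, 0) if lora_task_id is None else (1, lora_task_id)


print(sorted([7, None, 3, None, 7], key=lora_sort_key))   # [None, None, 3, 7, 7]
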
""" - def get_lora_task_id(req: LlmRequest) -> int: - # Return lora_task_id or a large value if not set + def get_lora_task_id(req: LlmRequest): + # C++ uses std::optional comparison where nullopt < any_value + # So requests without LoRA (nullopt) should come first lora_id = getattr(req, 'lora_task_id', None) if lora_id is None: - return float('inf') - return lora_id + return (0, 0) # (has_value=False, value=0) - comes first + return (1, lora_id) # (has_value=True, value) - sorted by value if chunks_present: # Partition: non-last-chunk first, last-chunk at end @@ -761,7 +762,7 @@ class MaxUtilizationPolicy(SchedulerPolicyBase): def schedule( self, scheduler: 'PyCapacityScheduler', active_requests: RequestList) -> Tuple[RequestList, RequestList]: - scheduler.kv_cache_manager.impl.start_scheduling() + scheduler.kv_cache_manager.start_scheduling() skipping_is_relevant = scheduler._is_skipping_relevant() @@ -824,7 +825,7 @@ def is_started_request(req: LlmRequest) -> bool: if last_started_idx is not None: paused_req = requests_list[last_started_idx] - scheduler.kv_cache_manager.impl.scheduling_remove_sequence( + scheduler.kv_cache_manager.scheduling_remove_sequence( paused_req.py_request_id) paused_requests.append(paused_req) logger.debug( @@ -887,7 +888,7 @@ def __init__(self, kv_cache_manager): C++ equivalent: mAvailableBlocks = mKvCacheManager.getBlockManager().getNumFreeBlocksPerWindowSize() """ self.kv_cache_manager = kv_cache_manager - stats = kv_cache_manager.impl.get_kv_cache_stats() + stats = kv_cache_manager.get_kv_cache_stats() self.available_blocks: Dict[int, int] = dict( stats.num_free_blocks_per_window_size) @@ -897,7 +898,7 @@ def decrement_reserved_blocks(self, req: LlmRequest) -> None: C++ reference: scheduledBlocksManager.h:40-46 """ for window_size in self.available_blocks: - needed = self.kv_cache_manager.impl.get_remaining_blocks_to_completion( + needed = self.kv_cache_manager.get_remaining_blocks_to_completion( req, window_size) self.available_blocks[window_size] -= needed @@ -907,9 +908,8 @@ def enough_available_blocks(self, req: LlmRequest) -> bool: C++ reference: scheduledBlocksManager.h:48-57 """ return all( - self.kv_cache_manager.impl.get_remaining_blocks_to_completion( - req, ws) <= avail - for ws, avail in self.available_blocks.items()) + self.kv_cache_manager.get_remaining_blocks_to_completion(req, ws) <= + avail for ws, avail in self.available_blocks.items()) class MaxUtilizationScheduledBlocksManager: @@ -942,13 +942,13 @@ def prepare_blocks_if_schedulable( """ blocks_if_scheduled = {} for window_size, num_scheduled in self.num_scheduled_blocks.items(): - required = self.kv_cache_manager.impl.get_needed_blocks_one_step( + required = self.kv_cache_manager.get_needed_blocks_one_step( req, self.two_steps_look_ahead, window_size) logger.debug( f"MaxUtilizationScheduler: request ID {req.request_id} " f"required blocks {required} for {window_size} window size") scheduled_total = num_scheduled + required - has_free = self.kv_cache_manager.impl.scheduling_has_free_blocks( + has_free = self.kv_cache_manager.scheduling_has_free_blocks( scheduled_total, window_size) if not has_free: return None @@ -1063,10 +1063,10 @@ def _is_skipping_relevant(self) -> bool: """ if self.kv_cache_manager is None: return False - if self.kv_cache_manager.is_vswa: + if self.kv_cache_manager.is_variable_window: return False if (self.cross_kv_cache_manager is not None - and self.cross_kv_cache_manager.is_vswa): + and self.cross_kv_cache_manager.is_variable_window): return False return True @@ 
-1094,7 +1094,7 @@ def _prefill_contributed_blocks( # Chunked context request already executing if enable_block_reuse: unique_tokens = req.get_unique_tokens(0) - block_key = self.kv_cache_manager.impl.find_new_context_block( + block_key = self.kv_cache_manager.find_new_context_block( unique_tokens, req) if block_key is not None: newly_contributed_context_blocks.add(block_key) @@ -1102,7 +1102,7 @@ def _prefill_contributed_blocks( if cross_enable_reuse: encoder_unique_tokens = req.get_encoder_unique_tokens() if encoder_unique_tokens is not None: - block_key = self.cross_kv_cache_manager.impl.find_new_context_block( + block_key = self.cross_kv_cache_manager.find_new_context_block( encoder_unique_tokens, req) if block_key is not None: newly_contributed_cross_context_blocks.add( @@ -1117,7 +1117,7 @@ def _one_manager_beneficial_to_skip(self, kv_cache_manager, unique_tokens, Check if skipping is beneficial for one KV cache manager. C++ reference: capacityScheduler.cpp:70-92 (oneManagerBeneficialToSkip) """ - new_context_block = kv_cache_manager.impl.find_new_context_block( + new_context_block = kv_cache_manager.find_new_context_block( unique_tokens, req) if new_context_block is not None: if new_context_block in newly_contributed_blocks: @@ -1161,7 +1161,7 @@ def _get_max_peft_pages(self) -> int: """Get maximum PEFT cache pages.""" if self.peft_cache_manager is None: return 2**31 - 1 # INT_MAX equivalent - return self.peft_cache_manager.get_max_device_pages() + return self.peft_cache_manager.max_device_pages def _get_peft_pages_for_request(self, req: LlmRequest) -> int: """Get PEFT pages needed for a request.""" From 7a265286293a66204dbdc418fb8063a39cca7888 Mon Sep 17 00:00:00 2001 From: Lance Liao <108499334+lancelly@users.noreply.github.com> Date: Thu, 25 Dec 2025 22:17:26 -0800 Subject: [PATCH 24/25] fix scheduler capacity for disagg gen init reqs Signed-off-by: Lance Liao <108499334+lancelly@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 3 ++- tensorrt_llm/_torch/pyexecutor/scheduler.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 138773cdb16..f856a5ef7d5 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -848,7 +848,8 @@ def create_py_executor_instance( if peft_cache_manager is not None else None, scheduler_policy=scheduler_config.capacity_scheduler_policy, ctx_chunk_config=ctx_chunk_config, - two_step_lookahead=mapping.has_pp()) + two_step_lookahead=mapping.has_pp(), + scheduler_capacity=scheduler_capacity) else: capacity_scheduler = BindCapacityScheduler( scheduler_capacity, diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 4efa06d2ed5..e45231c1d5c 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -1245,11 +1245,16 @@ def __init__( ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None, cross_kv_cache_manager=None, two_step_lookahead: bool = False, + scheduler_capacity: Optional[int] = None, ): + # Use scheduler_capacity if provided, otherwise fall back to max_batch_size + # scheduler_capacity may differ from max_batch_size (e.g., adjusted for attention_dp + disagg) + capacity = scheduler_capacity if scheduler_capacity is not None else max_batch_size + # 1. 
Initialize Python Capacity Scheduler # Now fully aligned with C++ CapacityScheduler self.capacity_scheduler = PyCapacityScheduler( - max_num_requests=max_batch_size, + max_num_requests=capacity, kv_cache_manager=kv_cache_manager, peft_cache_manager=peft_cache_manager, scheduler_policy=scheduler_policy, From 4b65790fbf59d2d183c747d9b4c60b69c89971bb Mon Sep 17 00:00:00 2001 From: Lanyu Liao Date: Fri, 26 Dec 2025 21:47:52 -0800 Subject: [PATCH 25/25] use cpp scheduler by default for now Signed-off-by: Lanyu Liao --- tensorrt_llm/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 2f4ff469e55..cea56431b77 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -17,7 +17,6 @@ # Disable UCC to WAR allgather issue before NGC PyTorch 25.12 upgrade. os.environ["OMPI_MCA_coll_ucc_enable"] = "0" -os.environ["TLLM_USE_PYTHON_SCHEDULER"] = "1" def _add_trt_llm_dll_directory():
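
With the hard-coded default removed above, the Python scheduler becomes opt-in again. Assuming the executor still consults TLLM_USE_PYTHON_SCHEDULER when it is already present in the environment, users who want the Python path could set it before importing the package; the gating logic itself is not shown in this series, so treat the snippet below as a sketch rather than a guaranteed switch.

import os

# Opt back into the Python scheduler before tensorrt_llm is imported, so the
# toggle is visible during package initialization.
os.environ.setdefault("TLLM_USE_PYTHON_SCHEDULER", "1")

import tensorrt_llm  # noqa: E402
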