@@ -106,6 +106,7 @@ class BatchState:
class BatchStatePP(BatchState):
    """Batch state extended with per-micro-batch bookkeeping.

    NOTE(review): presumably one instance is stored per slot of
    ``self.micro_batches`` by ``_executor_loop_pp`` - confirm against the
    construction site.
    """
    # Index of the micro batch this state was created for; defaults to -1.
    microbatch_id: int = -1
    # Context requests that were scheduled for this micro batch; defaults to None.
    scheduled_ctx_reqs: list[LlmRequest] = None
    # Context requests of this micro batch that were on their last context
    # chunk when scheduled (the list returned by _add_inflight_ids);
    # defaults to None.
    finished_ctx_reqs: list[LlmRequest] = None
109110
110111
111112class PyExecutor :
@@ -232,6 +233,8 @@ def __init__(self,
232233 | None ] = [None ] * self .num_micro_batches
233234 self .send_handles = [None ] * self .num_micro_batches
234235
236+ # Set of request IDs that are currently in flight across all micro batches.
237+ # The scheduler will avoid scheduling requests that are already in flight.
235238 self .inflight_req_ids = ReqIdsSet ()
236239
237240 # During warmup, we don't enable the profiler
@@ -694,7 +697,7 @@ def get_queued_req_stats(request_id: int) -> RequestStats:
694697 return req_stats
695698
696699 def _update_iter_stats (self , stats , iter_latency_ms , num_completed_requests ,
697- scheduled_batch ) -> IterationStats :
700+ scheduled_batch , micro_batch_id ) -> IterationStats :
698701 stats .iter_latency_ms = iter_latency_ms
699702
700703 stats .num_queued_requests = self .executor_request_queue .get_request_queue_size (
@@ -735,7 +738,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
735738 stats .inflight_batching_stats .num_paused_requests = len (
736739 scheduled_batch .paused_requests )
737740 stats .inflight_batching_stats .avg_num_decoded_tokens_per_iter = 0
738- stats .inflight_batching_stats .micro_batch_id = 0
741+ stats .inflight_batching_stats .micro_batch_id = micro_batch_id
739742 if stats .specdec_stats is not None :
740743 stats .specdec_stats .draft_overhead = 0.0 if iter_latency_ms <= 0.0 else float (
741744 stats .specdec_stats .iter_latency_ms ) / float (iter_latency_ms )
@@ -748,9 +751,13 @@ def _append_iter_stats(self,
748751 with self .stats_lock :
749752 self .stats .append ((stats , req_stats ))
750753
751- def _process_iter_stats (self , finished_requests : list [LlmRequest ],
752- active_requests : List [LlmRequest ],
753- batch_state : BatchState ):
754+ def _process_iter_stats (
755+ self ,
756+ finished_requests : list [LlmRequest ],
757+ active_requests : List [LlmRequest ],
758+ batch_state : BatchState ,
759+ micro_batch_id : int = 0 ,
760+ ):
754761 iter_end_time = time .time ()
755762 iter_latency_ms = (iter_end_time - batch_state .iter_start_time ) * 1e3
756763 if batch_state .iter_stats is None :
@@ -763,9 +770,10 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest],
763770 and self .enable_iter_perf_stats ) else None
764771
765772 self ._append_iter_stats (
766- self ._update_iter_stats (
767- batch_state .iter_stats , iter_latency_ms , len (finished_requests ),
768- batch_state .sample_state .scheduled_requests ), req_stats )
773+ self ._update_iter_stats (batch_state .iter_stats , iter_latency_ms ,
774+ len (finished_requests ),
775+ batch_state .sample_state .scheduled_requests ,
776+ micro_batch_id ), req_stats )
769777
770778 def _executor_loop_cleanup (self ):
771779
@@ -825,6 +833,7 @@ def _executor_loop_pp(self):
825833 self .num_scheduled_requests = scheduled_batch .batch_size
826834
827835 logger .debug (
836+ f'iteration { self .iter_counter } , microbatch { microbatch_id } , '
828837 f'has { len (self .active_requests )} active_requests, '
829838 f'scheduled { len (scheduled_batch .context_requests )} context requests and '
830839 f'{ len (scheduled_batch .generation_requests )} generation requests'
@@ -833,9 +842,13 @@ def _executor_loop_pp(self):
833842 can_queue = self ._can_queue (scheduled_batch )
834843
835844 if not can_queue :
845+ logger .debug (
846+ f"microbatch { microbatch_id } cannot be queued, skipping"
847+ )
836848 self .micro_batches [microbatch_id ] = None
837849 else :
838- self ._add_inflight_ids (scheduled_batch )
850+ logger .debug (f"microbatch { microbatch_id } can be queued" )
851+ finished_ctx_reqs = self ._add_inflight_ids (scheduled_batch )
839852
840853 if self .kv_cache_transceiver :
841854 # For generation requests which have completed KV cache transfer
@@ -895,6 +908,7 @@ def _executor_loop_pp(self):
895908 iter_stats = iter_stats ,
896909 microbatch_id = microbatch_id ,
897910 scheduled_ctx_reqs = scheduled_batch .context_requests ,
911+ finished_ctx_reqs = finished_ctx_reqs ,
898912 )
899913
900914 self .micro_batches [microbatch_id ] = batch_state
@@ -945,6 +959,8 @@ def _executor_loop_pp(self):
945959 finished_requests = []
946960 if previous_batch is not None :
947961 with torch .cuda .nvtx .range ("_handle_previous_batch_pp" ):
962+ sample_state = previous_batch .sample_state
963+ sample_state .scheduled_requests .context_requests = previous_batch .finished_ctx_reqs
948964 self ._update_requests (previous_batch .sample_state )
949965
950966 if self .block_reuse_enabled and not self .kv_cache_manager .is_vswa and self .kv_cache_transceiver :
@@ -976,7 +992,8 @@ def _executor_loop_pp(self):
976992 self .resource_manager .update_resources (
977993 previous_scheduled_batch , attn_metadata ,
978994 kv_cache_dtype_byte_size )
979- self ._remove_inflight_ids (previous_scheduled_batch )
995+
996+ self ._remove_inflight_ids (previous_batch )
980997
981998 self .wait_on_pp_send_handles (prev_microbatch_id )
982999 self .micro_batches [prev_microbatch_id ] = None
@@ -993,9 +1010,11 @@ def _executor_loop_pp(self):
9931010 microbatch_id = (microbatch_id + 1 ) % self .num_micro_batches
9941011
9951012 if self .enable_iter_perf_stats and previous_batch is not None :
1013+ sample_state = previous_batch .sample_state
1014+ sample_state .scheduled_requests .context_requests = previous_batch .scheduled_ctx_reqs
9961015 self ._process_iter_stats (finished_requests ,
9971016 self .active_requests ,
998- previous_batch )
1017+ previous_batch , microbatch_id )
9991018
10001019 self .iter_counter += 1
10011020
@@ -2481,13 +2500,43 @@ def _pause_requests(self, requests_to_pause):
24812500 self ._terminate_request (req )
24822501
def _add_inflight_ids(self, scheduled_requests):
    """Register the requests of a scheduled batch in self.inflight_req_ids.

    A context request is registered only when it is on its last context
    chunk. Requests that still have chunks pending stay out of the set, so
    the scheduler can keep scheduling their remaining chunks while earlier
    chunks are in the PP pipeline. Every generation request is registered.

    Args:
        scheduled_requests: Batch exposing ``context_requests`` and
            ``generation_requests``.

    Returns:
        The context requests that were registered, i.e. those whose
        ``is_last_context_chunk`` flag is set, in scheduling order.
    """
    inflight = self.inflight_req_ids
    last_chunk_ctx_reqs = []

    for ctx_req in scheduled_requests.context_requests:
        # Intermediate chunks must remain schedulable, so skip them.
        if not ctx_req.is_last_context_chunk:
            continue
        logger.debug(
            f"Context request with ID {ctx_req.request_id} added to DECODER model inflight set"
        )
        inflight.insert(ctx_req.request_id)
        last_chunk_ctx_reqs.append(ctx_req)

    for gen_req in scheduled_requests.generation_requests:
        logger.debug(
            f"Generation request with ID {gen_req.request_id} added to DECODER model inflight set"
        )
        inflight.insert(gen_req.request_id)

    return last_chunk_ctx_reqs
2524+
def _remove_inflight_ids(self, batch_state: BatchStatePP):
    """Remove the request IDs of a micro batch from self.inflight_req_ids.

    Args:
        batch_state: State of the micro batch being retired. Context IDs are
            taken from ``batch_state.finished_ctx_reqs``; generation IDs are
            taken from
            ``batch_state.sample_state.scheduled_requests.generation_requests``.
    """
    # finished_ctx_reqs defaults to None on BatchStatePP; treat a missing
    # list as "no context request finished" instead of raising TypeError.
    for req in batch_state.finished_ctx_reqs or ():
        logger.debug(
            f"Context request with ID {req.request_id} removed from DECODER model inflight set"
        )
        self.inflight_req_ids.erase(req.request_id)

    for req in batch_state.sample_state.scheduled_requests.generation_requests:
        logger.debug(
            f"Generation request with ID {req.request_id} removed from DECODER model inflight set"
        )
        self.inflight_req_ids.erase(req.request_id)
24922541
24932542 def _handle_speculative_decoding (self , scheduled_batch , previous_tensors ,
0 commit comments