@@ -1161,7 +1161,7 @@ def _executor_loop_pp(self):
11611161 f'{ len (scheduled_batch .generation_requests )} generation requests'
11621162 )
11631163
1164- can_queue = self ._can_queue (scheduled_batch )
1164+ can_queue , _ = self ._can_queue (scheduled_batch )
11651165 if not can_queue :
11661166 logger .debug (
11671167 f"microbatch { microbatch_id } cannot be queued, skipping"
@@ -1359,13 +1359,17 @@ def wait_on_pp_send_handles(self, microbatch_id):
13591359
13601360 def _can_queue (self , scheduled_batch ):
13611361
1363+ # can_queue_this_rank covers the case where the batch is non-empty on this rank but empty on some other rank.
1363+ # For bs == 1, we cannot pad a dummy request to make the batch non-empty, since that would raise the batch size to 2:
1364+ # one for the dummy request and one for the request that has completed but has not yet been updated.
13621365 if self .enable_attention_dp :
13631366 tp_batch_sizes = self .dist .tp_allgather (scheduled_batch .batch_size )
13641367 can_queue = 0 not in tp_batch_sizes
1368+ can_queue_this_rank = scheduled_batch .batch_size > 0
13651369 else :
1366- can_queue = scheduled_batch .batch_size > 0
1370+ can_queue = can_queue_this_rank = scheduled_batch .batch_size > 0
13671371
1368- return can_queue
1372+ return can_queue , can_queue_this_rank
13691373
13701374 def _prepare_and_schedule_batch (self ):
13711375 new_requests = self ._fetch_and_activate_new_requests ()
@@ -1494,7 +1498,7 @@ def _executor_loop(self):
14941498
14951499 finished_requests = []
14961500
1497- can_queue = self ._can_queue (scheduled_batch )
1501+ can_queue , _ = self ._can_queue (scheduled_batch )
14981502 if can_queue :
14991503 if self .kv_cache_transceiver :
15001504 # For generation requests which have completed KV cache transfer
@@ -1509,7 +1513,7 @@ def _executor_loop(self):
15091513
15101514 # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
15111515 if self .kv_connector_manager :
1512- can_queue = self ._can_queue (scheduled_batch )
1516+ can_queue , _ = self ._can_queue (scheduled_batch )
15131517
15141518 if can_queue :
15151519 # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers.
@@ -1711,7 +1715,8 @@ def _executor_loop_overlap(self):
17111715
17121716 self ._pause_requests (scheduled_batch .paused_requests )
17131717
1714- can_queue = self ._can_queue (scheduled_batch )
1718+ can_queue , can_queue_this_rank = self ._can_queue (
1719+ scheduled_batch )
17151720 if can_queue :
17161721 if self .kv_cache_transceiver :
17171722 # For generation requests which have completed KV cache transfer
@@ -1741,8 +1746,13 @@ def _executor_loop_overlap(self):
17411746
17421747 # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
17431748 if self .kv_connector_manager :
1744- can_queue = self ._can_queue (scheduled_batch )
1749+ can_queue , can_queue_this_rank = self ._can_queue (
1750+ scheduled_batch )
17451751
1752+ # If the batch is non-empty on this rank but empty on other ranks,
1753+ # delay updating the previous batch's sample state
1754+ # and let a later iteration perform the update.
1755+ should_process_previous_batch = can_queue or not can_queue_this_rank
17461756 if can_queue :
17471757
17481758 # The generation requests that are do not have batch_idx,
@@ -1792,10 +1802,10 @@ def _executor_loop_overlap(self):
17921802 scheduled_batch , previous_tensors_device ,
17931803 num_accepted_tokens_device )
17941804
1795- if self .previous_batch is not None :
1805+ if self .previous_batch is not None and should_process_previous_batch :
17961806 self ._update_requests (self .previous_batch .sample_state )
17971807
1798- if self .drafter is not None and self .use_spec_decode :
1808+ if self .drafter is not None and self .use_spec_decode and should_process_previous_batch :
17991809 # Cleanup previous draft resources used in the draft model
18001810 self .drafter .cleanup_previous_draft_resources ()
18011811
@@ -1822,8 +1832,10 @@ def _executor_loop_overlap(self):
18221832 ctx_transmission_reqs = self ._send_kv_async (
18231833 scheduled_batch .all_requests ())
18241834
1825- if self .previous_batch is not None :
1835+ if self .previous_batch is not None and should_process_previous_batch :
18261836 self ._process_previous_batch ()
1837+ else :
1838+ self ._enqueue_responses ([])
18271839
18281840 if can_queue :
18291841 if self .enable_iter_perf_stats :
@@ -1835,6 +1847,9 @@ def _executor_loop_overlap(self):
18351847 iter_start_time = iter_start_time ,
18361848 iter_stats = iter_stats ,
18371849 ctx_transmission_reqs = ctx_transmission_reqs )
1850+ elif not can_queue_this_rank :
1851+ # If the batch is empty on this rank, we need to clear the previous batch.
1852+ self .previous_batch = None
18381853
18391854 if self .kv_cache_transceiver and self .async_transfer_manager .has_any_inflight_requests (
18401855 ):
@@ -2194,10 +2209,10 @@ def _pad_attention_dp_dummy_request(self):
21942209 if self .kv_cache_transceiver is None :
21952210 num_active_request = len (self .active_requests )
21962211 else :
2197- num_active_request = sum ([
2198- 0 if req . is_disagg_generation_init_state
2199- or req .is_disagg_generation_transmission_in_progress else 1
2200- for req in self . active_requests
2212+ num_active_request = len ([
2213+ req for req in self . active_requests
2214+ if not ( req .is_disagg_generation_init_state
2215+ or req . is_disagg_generation_transmission_in_progress )
22012216 ])
22022217
22032218 if self .expected_num_active_requests - num_active_request > 0 and num_active_request == 0 :
@@ -2393,7 +2408,7 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0):
23932408
23942409 def _forward_step (
23952410 self ,
2396- scheduled_requests ,
2411+ scheduled_requests : ScheduledRequests ,
23972412 new_tensors_device : Optional [SampleStateTensors ] = None ,
23982413 num_accepted_tokens_device : Optional [torch .Tensor ] = None ):
23992414 ExpertStatistic .set_iter (self .iter_counter )
0 commit comments