 import time
 import traceback
 import weakref
-from collections import namedtuple
+from collections import deque, namedtuple
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Tuple, Union

class RequestQueueItem:
    id: int
    request: Optional[ExecutorRequest] = None
+    is_canceled_request: bool = False
    query: Optional[list] = None  # only used in `StarAttention`

+    @property
    def is_shutdown_request(self):
        return self.id == SHUTDOWN_REQUEST_ID

+    @property
+    def is_normal_request(self):
+        return not (self.is_shutdown_request or self.is_canceled_request)

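The class above now distinguishes three kinds of queue items: normal requests, the shutdown sentinel, and the new cancellation markers, with `is_shutdown_request` promoted from a method to a property (call sites below drop the `()` accordingly). A self-contained sketch of the resulting routing contract; the stub types and the sentinel value are assumptions, only the names match the diff:

```python
from dataclasses import dataclass
from typing import Optional

SHUTDOWN_REQUEST_ID = -1  # assumed sentinel value; defined elsewhere in the real module


@dataclass
class RequestQueueItem:
    id: int
    request: Optional[object] = None  # ExecutorRequest in the real code
    is_canceled_request: bool = False

    @property
    def is_shutdown_request(self):
        return self.id == SHUTDOWN_REQUEST_ID

    @property
    def is_normal_request(self):
        return not (self.is_shutdown_request or self.is_canceled_request)


# Exactly one predicate holds for each kind of item the fetch loop must route:
assert RequestQueueItem(SHUTDOWN_REQUEST_ID).is_shutdown_request
assert RequestQueueItem(42, is_canceled_request=True).is_canceled_request
assert RequestQueueItem(7, request=object()).is_normal_request
```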
-def _get_from_request_queue(request_queue,
-                            timeout: Optional[datetime.timedelta],
-                            max_req_count: int) -> List[RequestQueueItem]:
+
+def _get_from_request_queue(
+        request_queue,
+        timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
    items = []
    timeout_secs = timeout.total_seconds() if timeout is not None else None
-    req_count = 0
    try:
        if request_queue.empty() and (timeout_secs is None or timeout_secs > 0):
            # if queue is empty and want to wait, wait
            items.append(request_queue.get(timeout=timeout_secs))
        else:
            # if not empty or don't want to wait, just return all items in queue
-            while req_count < max_req_count:
+            while True:
                queue_item = request_queue.get_nowait()
                items.append(queue_item)
-                if not queue_item.is_shutdown_request():
-                    req_count += 1
    except queue.Empty:
        pass
    return items


+def _get_from_waiting_queue(
+    waiting_queue: deque[RequestQueueItem],
+    max_req_count: int,
+) -> List[RequestQueueItem]:
+    """Safely extract up to max_req_count items from a deque.
+
+    Args:
+        waiting_queue: The queue to pop items from.
+        max_req_count: Maximum number of items to retrieve. Returns an
+            empty list if it is <= 0.
+
+    Returns:
+        List of retrieved items (may be shorter than max_req_count if the
+        queue empties first).
+    """
+    # Edge case: non-positive counts are a no-op, not an error.
+    if max_req_count <= 0:
+        return []
+
+    items = []
+    req_count = 0
+    while req_count < max_req_count and waiting_queue:
+        items.append(waiting_queue.popleft())
+        req_count += 1
+    return items
+
+
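Capacity limiting moves out of `_get_from_request_queue`, which now drains the inter-process queue completely, and into the new `_get_from_waiting_queue`, which pops from an in-memory staging deque. The apparent intent is that control items (shutdown, cancellation) are always picked up promptly instead of waiting behind the active-request cap. A usage sketch of the helper above, with bare ints standing in for `RequestQueueItem`s:

```python
from collections import deque

# Bare ints stand in for RequestQueueItem objects in this sketch.
waiting_queue = deque([1, 2, 3, 4, 5])

batch = _get_from_waiting_queue(waiting_queue, max_req_count=3)
assert batch == [1, 2, 3]
assert list(waiting_queue) == [4, 5]  # remainder stays staged for later iterations

# Non-positive counts (e.g. when active requests already fill capacity)
# return nothing rather than raising:
assert _get_from_waiting_queue(waiting_queue, max_req_count=0) == []
```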
@functools.cache
def _load_iteration_indexes(env_var: str):
    spans = os.environ.get(env_var, None)
@@ -182,6 +210,7 @@ def __init__(self,
        self.device_id = torch.cuda.current_device()
        self.global_rank = global_mpi_rank()
        self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
+        self.waiting_queue: deque[RequestQueueItem] = deque()

        # profile config
        self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
@@ -251,7 +280,7 @@ def __init__(self,
        self.send_handles = [None] * self.num_micro_batches

        self.inflight_req_ids = ReqIdsSet()
-        self.canceled_req_ids = ReqIdsSet()
+        self.canceled_req_ids = []

        self.model_engine.warmup(self.resource_manager)
        if self.draft_model_engine is not None:
@@ -368,7 +397,12 @@ def cancel_request(self, id: int):
        Args:
            id (int): The request id for which to cancel the response
        """
-        self.canceled_req_ids.insert(id)
+        try:
+            self.enqueue_lock.acquire()
+            self.request_queue.put(
+                RequestQueueItem(id, is_canceled_request=True))
+        finally:
+            self.enqueue_lock.release()

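Cancellation no longer mutates `canceled_req_ids` directly from the caller's thread; it is serialized through the same queue and `enqueue_lock` as regular enqueues, then routed by `_fetch_new_requests` on the executor side. A minimal sketch of that handshake, with a plain `threading.Lock` standing in for the executor's `enqueue_lock` and `RequestQueueItem` as sketched earlier; the `with` block is equivalent to the explicit `acquire`/`release` pair in the diff:

```python
import queue
import threading

request_queue: queue.Queue = queue.Queue()
enqueue_lock = threading.Lock()  # stands in for self.enqueue_lock
canceled_req_ids = []


def cancel_request(req_id: int) -> None:
    # Producer side: a cancellation is just another queue item, ordered
    # consistently with regular enqueues under the same lock.
    with enqueue_lock:
        request_queue.put(RequestQueueItem(req_id, is_canceled_request=True))


cancel_request(42)

# Consumer side (see _fetch_new_requests below): the marker is routed to
# canceled_req_ids instead of becoming an active request.
item = request_queue.get_nowait()
if item.is_canceled_request:
    canceled_req_ids.append(item.id)
assert canceled_req_ids == [42]
```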
    def shutdown(self):
        """
@@ -454,6 +488,11 @@ def enqueue_request(self,
    def set_gather_responses(self, gather_all_responses):
        self.gather_all_responses = gather_all_responses

+    @property
+    def should_stop_processing(self):
+        return self.is_shutdown and len(self.active_requests) == 0 and len(
+            self.waiting_queue) == 0
+
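`should_stop_processing` folds the staging queue into the executor loops' exit condition: with requests now parked in `waiting_queue`, stopping on `is_shutdown and not active_requests` alone could exit while staged work is still pending. A compact sketch of the predicate, with plain arguments standing in for the executor's state:

```python
# Plain arguments stand in for self.is_shutdown / active_requests / waiting_queue.
def should_stop(is_shutdown: bool, num_active: int, num_waiting: int) -> bool:
    return is_shutdown and num_active == 0 and num_waiting == 0

assert not should_stop(True, 0, 3)  # shutdown seen, but staged requests remain
assert not should_stop(True, 2, 0)  # active requests still in flight
assert should_stop(True, 0, 0)      # fully drained: safe to exit the loop
```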
    @contextmanager
    def _profiler(self):
        it = -1
@@ -710,12 +749,12 @@ def _executor_loop_pp(self):
        with self._profiler() as profile_step:
            iter_start_time = time.time()
            iter_stats = None
-            while not self.is_shutdown or len(self.active_requests) > 0:
+            while not self.should_stop_processing:
                profile_step()
                if self.enable_iter_perf_stats:
                    iter_start_time = time.time()
                new_requests = self._fetch_new_requests()
-                if self.is_shutdown and len(self.active_requests) == 0:
+                if self.should_stop_processing:
                    break

                if self.enable_iter_perf_stats:
@@ -839,7 +878,7 @@ def _executor_loop_pp(self):
                if previous_batch is not None:
                    with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
                        self._update_requests(previous_batch.sample_state)
-                        self._handle_cancelled_requests()
+                        self._handle_canceled_requests()
                        finished_requests = self._handle_responses()
                        previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
                        self.resource_manager.update_resources(
@@ -861,12 +900,12 @@ def _executor_loop(self):
            sample_state = None
            iter_start_time = time.time()
            iter_stats = None
-            while not self.is_shutdown or len(self.active_requests) > 0:
+            while not self.should_stop_processing:
                profile_step()
                if self.enable_iter_perf_stats:
                    iter_start_time = time.time()
                new_requests = self._fetch_new_requests()
-                if self.is_shutdown and len(self.active_requests) == 0:
+                if self.should_stop_processing:
                    break

                if self.kv_cache_transceiver:
@@ -950,7 +989,7 @@ def _executor_loop(self):
                    for req in ctx_transmission_reqs:
                        req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS

-                self._handle_cancelled_requests()
+                self._handle_canceled_requests()
                finished_requests = self._handle_responses()
                self.resource_manager.update_resources(scheduled_batch)
                if self.enable_kv_cache_events:
@@ -1006,12 +1045,12 @@ def _executor_loop_overlap(self):
        with self._profiler() as profile_step:
            iter_start_time = time.time()
            iter_stats = None
-            while not self.is_shutdown or len(self.active_requests) > 0:
+            while not self.should_stop_processing:
                profile_step()
                if self.enable_iter_perf_stats:
                    iter_start_time = time.time()
                new_requests = self._fetch_new_requests()
-                if self.is_shutdown and len(self.active_requests) == 0:
+                if self.should_stop_processing:
                    break

                if self.kv_cache_transceiver:
@@ -1125,7 +1164,7 @@ def _process_previous_batch(self):
        for req in self.previous_batch.ctx_transmission_reqs:
            req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS

-        self._handle_cancelled_requests()
+        self._handle_canceled_requests()
        finished_requests = self._handle_responses()
        scheduled_requests = self.previous_batch.sample_state.scheduled_requests
        self.resource_manager.update_resources(scheduled_requests)
@@ -1200,13 +1239,11 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
        total_num_active_requests = len(self.active_requests)
        total_max_num_active_requests = self.max_num_active_requests

-        timeout = None if total_num_active_requests == 0 else datetime.timedelta(
-            0)
+        timeout = None if (total_num_active_requests == 0) and len(
+            self.waiting_queue) == 0 else datetime.timedelta(0)
        new_requests = []
        if self.dist.rank == 0:
-            new_requests = _get_from_request_queue(
-                self.request_queue, timeout,
-                total_max_num_active_requests - total_num_active_requests)
+            new_requests = _get_from_request_queue(self.request_queue, timeout)

        if self.dist.rank == 0:
            py_logits_post_processors = self._collect_py_objects_from_requests(
@@ -1229,21 +1266,28 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
        # drop requests arriving after shutdown
        valid_new_requests = []
        for req_item in new_requests:
-            if req_item.is_shutdown_request():
+            if req_item.is_shutdown_request:
                self.is_shutdown = True
                break
+            elif req_item.is_canceled_request:
+                self.canceled_req_ids.append(req_item.id)
            else:
                valid_new_requests.append(req_item)
        # Check if the beam width of the requests is equal to the max_beam_width
        for req_item in valid_new_requests:
            assert req_item.request.sampling_config.beam_width == self.max_beam_width, f"Request beam width {req_item.request.sampling_config.beam_width} is not equal to max_beam_width {self.max_beam_width}. This is not supported!"
-        new_requests = valid_new_requests

        if py_request_objects and (self.dist.tp_size > 1
                                   or self.dist.has_pp) and self.dist.rank > 0:
            for attr_name, req_obj_dict in py_request_objects:
-                self._attach_py_objects_to_requests(new_requests, attr_name,
-                                                    req_obj_dict)
+                self._attach_py_objects_to_requests(valid_new_requests,
+                                                    attr_name, req_obj_dict)
+
+        self.waiting_queue.extend(valid_new_requests)
+
+        new_requests = _get_from_waiting_queue(
+            self.waiting_queue,
+            total_max_num_active_requests - total_num_active_requests)

        if not self.enable_attention_dp:
            self._update_new_active_requests_queue_latency(new_requests)
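Taken together, rank 0's fetch path is now a two-stage pipeline: drain the inter-process queue, route control items (shutdown, cancellation), stage the survivors, then admit only what fits in the free capacity. A condensed restatement of the hunk above, assuming an `executor` object with the attributes named in the diff:

```python
def fetch_new_requests_sketch(executor, timeout):
    # Stage 1: drain everything currently in the inter-process queue.
    arrived = _get_from_request_queue(executor.request_queue, timeout)

    valid_new_requests = []
    for item in arrived:
        if item.is_shutdown_request:    # sentinel: stop accepting work
            executor.is_shutdown = True
            break
        elif item.is_canceled_request:  # marker: handled later in _handle_canceled_requests
            executor.canceled_req_ids.append(item.id)
        else:
            valid_new_requests.append(item)

    # Stage 2: park everything, then admit only up to the free capacity.
    executor.waiting_queue.extend(valid_new_requests)
    free_slots = executor.max_num_active_requests - len(executor.active_requests)
    return _get_from_waiting_queue(executor.waiting_queue, free_slots)
```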
@@ -1339,7 +1383,7 @@ def _collect_py_objects_from_requests(
        """
        req_id_to_obj = {}
        for item in requests:
-            if item.is_shutdown_request():
+            if not item.is_normal_request:
                continue
            obj = getattr(item.request, attribute_name, None)
            if obj is not None:
@@ -1926,41 +1970,28 @@ def _handle_errors(self, error_msg: Optional[str] = None):
    def _terminate_request(self, request: LlmRequest):
        self.resource_manager.free_resources(request)

-    @nvtx_range("_handle_cancelled_requests")
-    def _handle_cancelled_requests(self):
-        #TODO: properly handle canceled ids in pp case
-        if self.dist.has_tp:
-            self.canceled_req_ids = self.dist.broadcast(self.canceled_req_ids,
-                                                        root=0)
-
+    @nvtx_range("_handle_canceled_requests")
+    def _handle_canceled_requests(self):
        if len(self.canceled_req_ids) == 0:
            return

-        cancelled_responses = {}
-        left_requests = []
-        # Tracks canceled requests for proper handling in overlap mode during `sampler.update_requests`.
-        self.canceled_requests = []
+        # Drop canceled requests that are still waiting in the staging queue.
+        self.waiting_queue = deque(req for req in self.waiting_queue
+                                   if req.id not in self.canceled_req_ids)
+
        for request in self.active_requests:
            req_id = request.py_request_id
            if req_id in self.canceled_req_ids:
-                self._terminate_request(request)
+                # Mark the request as finished; the existing response and
+                # cleanup paths then release its KV cache resources.
                request.finish_by_reason(FinishReason.CANCELLED)
                request.decoding_iter = request.py_decoding_iter
-                cancelled_responses[req_id] = request.create_response(
-                    False, self.dist.rank)
-                self.canceled_requests.append(request)
-                self.canceled_req_ids.erase(req_id)
-            else:
-                left_requests.append(request)
-        self.active_requests = left_requests

-        # When enable attention dp, each rank does not have full copy of requests
-        # so we need to remove the cancel requests not in the local rank
-        self.canceled_req_ids.clear()
-
-        # enqueue the cancelled requests' responses as they are not
-        # active_requests and be discarded in the sampler loop.
-        self._enqueue_responses(cancelled_responses)
+        if self.enable_attention_dp:
+            # TODO: revisit the cancellation logic for attention DP.
+            # With attention DP, each rank holds only a subset of the
+            # requests, so canceled ids that do not belong to the local
+            # rank must be dropped here.
+            self.canceled_req_ids.clear()

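The rewritten `_handle_canceled_requests` no longer terminates requests and hand-builds their responses; it filters the waiting queue and flags active requests with `FinishReason.CANCELLED`, letting the regular `_handle_responses` path produce the response and free KV cache resources. One consequence of `canceled_req_ids` becoming a plain list (see the `__init__` hunk above) is that the `req.id not in self.canceled_req_ids` membership tests are linear; a set would keep them O(1) if cancellation volume grows. A self-contained sketch of the filtering step:

```python
from collections import deque


# Stub items with just an `id`, standing in for RequestQueueItem.
class _Item:
    def __init__(self, id):
        self.id = id


waiting_queue = deque(_Item(i) for i in (1, 2, 3, 4))
canceled_req_ids = [2, 4]

# Same deque-rebuild pattern as the diff: canceled items simply vanish
# from staging before they ever become active requests.
waiting_queue = deque(item for item in waiting_queue
                      if item.id not in canceled_req_ids)
assert [item.id for item in waiting_queue] == [1, 3]
```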
    @nvtx_range("_enqueue_responses")
    def _enqueue_responses(self, responses: Dict[int, LlmResponse]):