Skip to content

Commit ce39409

Browse files
authored
fix cancel request logic (#5800)
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
1 parent a36ac45 commit ce39409

File tree

3 files changed

+87
-58
lines changed

3 files changed

+87
-58
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 86 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import time
1010
import traceback
1111
import weakref
12-
from collections import namedtuple
12+
from collections import deque, namedtuple
1313
from contextlib import contextmanager
1414
from typing import Dict, List, Optional, Tuple, Union
1515

@@ -57,34 +57,62 @@
5757
class RequestQueueItem:
5858
id: int
5959
request: Optional[ExecutorRequest] = None
60+
is_canceled_request: bool = False
6061
query: Optional[list] = None # only used in `StarAttention`
6162

63+
@property
6264
def is_shutdown_request(self):
6365
return self.id == SHUTDOWN_REQUEST_ID
6466

67+
@property
68+
def is_normal_request(self):
69+
return not (self.is_shutdown_request or self.is_canceled_request)
6570

66-
def _get_from_request_queue(request_queue,
67-
timeout: Optional[datetime.timedelta],
68-
max_req_count: int) -> List[RequestQueueItem]:
71+
72+
def _get_from_request_queue(
73+
request_queue,
74+
timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
6975
items = []
7076
timeout_secs = timeout.total_seconds() if timeout is not None else None
71-
req_count = 0
7277
try:
7378
if request_queue.empty() and (timeout_secs is None or timeout_secs > 0):
7479
# if queue is empty and want to wait, wait
7580
items.append(request_queue.get(timeout=timeout_secs))
7681
else:
7782
# if not empty or don't want to wait, just return all items in queue
78-
while req_count < max_req_count:
83+
while True:
7984
queue_item = request_queue.get_nowait()
8085
items.append(queue_item)
81-
if not queue_item.is_shutdown_request():
82-
req_count += 1
8386
except queue.Empty:
8487
pass
8588
return items
8689

8790

91+
def _get_from_waiting_queue(
92+
waiting_queue: deque[RequestQueueItem],
93+
max_req_count: int,
94+
) -> List[RequestQueueItem]:
95+
"""Safely extracts up to max_req_count items from a deque.
96+
97+
Args:
98+
waiting_queue: The queue to pop items from.
99+
max_req_count: Maximum items to retrieve. Returns empty list if <=0.
100+
101+
Returns:
102+
List of retrieved items (may be shorter than max_req_count if queue empties first).
103+
"""
104+
# Edge case handling
105+
if max_req_count <= 0: # Handles negative/zero counts
106+
return []
107+
108+
items = []
109+
req_count = 0
110+
while req_count < max_req_count and waiting_queue:
111+
items.append(waiting_queue.popleft())
112+
req_count += 1
113+
return items
114+
115+
88116
@functools.cache
89117
def _load_iteration_indexes(env_var: str):
90118
spans = os.environ.get(env_var, None)
@@ -182,6 +210,7 @@ def __init__(self,
182210
self.device_id = torch.cuda.current_device()
183211
self.global_rank = global_mpi_rank()
184212
self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
213+
self.waiting_queue: deque[RequestQueueItem] = deque()
185214

186215
# profile config
187216
self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
@@ -251,7 +280,7 @@ def __init__(self,
251280
self.send_handles = [None] * self.num_micro_batches
252281

253282
self.inflight_req_ids = ReqIdsSet()
254-
self.canceled_req_ids = ReqIdsSet()
283+
self.canceled_req_ids = []
255284

256285
self.model_engine.warmup(self.resource_manager)
257286
if self.draft_model_engine is not None:
@@ -368,7 +397,12 @@ def cancel_request(self, id: int):
368397
Args:
369398
id (int): The request id for which to cancel the response
370399
"""
371-
self.canceled_req_ids.insert(id)
400+
try:
401+
self.enqueue_lock.acquire()
402+
self.request_queue.put(
403+
RequestQueueItem(id, is_canceled_request=True))
404+
finally:
405+
self.enqueue_lock.release()
372406

373407
def shutdown(self):
374408
"""
@@ -454,6 +488,11 @@ def enqueue_request(self,
454488
def set_gather_responses(self, gather_all_responses):
455489
self.gather_all_responses = gather_all_responses
456490

491+
@property
492+
def should_stop_processing(self):
493+
return self.is_shutdown and len(self.active_requests) == 0 and len(
494+
self.waiting_queue) == 0
495+
457496
@contextmanager
458497
def _profiler(self):
459498
it = -1
@@ -710,12 +749,12 @@ def _executor_loop_pp(self):
710749
with self._profiler() as profile_step:
711750
iter_start_time = time.time()
712751
iter_stats = None
713-
while not self.is_shutdown or len(self.active_requests) > 0:
752+
while not self.should_stop_processing:
714753
profile_step()
715754
if self.enable_iter_perf_stats:
716755
iter_start_time = time.time()
717756
new_requests = self._fetch_new_requests()
718-
if self.is_shutdown and len(self.active_requests) == 0:
757+
if self.should_stop_processing:
719758
break
720759

721760
if self.enable_iter_perf_stats:
@@ -839,7 +878,7 @@ def _executor_loop_pp(self):
839878
if previous_batch is not None:
840879
with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
841880
self._update_requests(previous_batch.sample_state)
842-
self._handle_cancelled_requests()
881+
self._handle_canceled_requests()
843882
finished_requests = self._handle_responses()
844883
previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
845884
self.resource_manager.update_resources(
@@ -861,12 +900,12 @@ def _executor_loop(self):
861900
sample_state = None
862901
iter_start_time = time.time()
863902
iter_stats = None
864-
while not self.is_shutdown or len(self.active_requests) > 0:
903+
while not self.should_stop_processing:
865904
profile_step()
866905
if self.enable_iter_perf_stats:
867906
iter_start_time = time.time()
868907
new_requests = self._fetch_new_requests()
869-
if self.is_shutdown and len(self.active_requests) == 0:
908+
if self.should_stop_processing:
870909
break
871910

872911
if self.kv_cache_transceiver:
@@ -950,7 +989,7 @@ def _executor_loop(self):
950989
for req in ctx_transmission_reqs:
951990
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
952991

953-
self._handle_cancelled_requests()
992+
self._handle_canceled_requests()
954993
finished_requests = self._handle_responses()
955994
self.resource_manager.update_resources(scheduled_batch)
956995
if self.enable_kv_cache_events:
@@ -1006,12 +1045,12 @@ def _executor_loop_overlap(self):
10061045
with self._profiler() as profile_step:
10071046
iter_start_time = time.time()
10081047
iter_stats = None
1009-
while not self.is_shutdown or len(self.active_requests) > 0:
1048+
while not self.should_stop_processing:
10101049
profile_step()
10111050
if self.enable_iter_perf_stats:
10121051
iter_start_time = time.time()
10131052
new_requests = self._fetch_new_requests()
1014-
if self.is_shutdown and len(self.active_requests) == 0:
1053+
if self.should_stop_processing:
10151054
break
10161055

10171056
if self.kv_cache_transceiver:
@@ -1125,7 +1164,7 @@ def _process_previous_batch(self):
11251164
for req in self.previous_batch.ctx_transmission_reqs:
11261165
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
11271166

1128-
self._handle_cancelled_requests()
1167+
self._handle_canceled_requests()
11291168
finished_requests = self._handle_responses()
11301169
scheduled_requests = self.previous_batch.sample_state.scheduled_requests
11311170
self.resource_manager.update_resources(scheduled_requests)
@@ -1200,13 +1239,11 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
12001239
total_num_active_requests = len(self.active_requests)
12011240
total_max_num_active_requests = self.max_num_active_requests
12021241

1203-
timeout = None if total_num_active_requests == 0 else datetime.timedelta(
1204-
0)
1242+
timeout = None if (total_num_active_requests == 0) and len(
1243+
self.waiting_queue) == 0 else datetime.timedelta(0)
12051244
new_requests = []
12061245
if self.dist.rank == 0:
1207-
new_requests = _get_from_request_queue(
1208-
self.request_queue, timeout,
1209-
total_max_num_active_requests - total_num_active_requests)
1246+
new_requests = _get_from_request_queue(self.request_queue, timeout)
12101247

12111248
if self.dist.rank == 0:
12121249
py_logits_post_processors = self._collect_py_objects_from_requests(
@@ -1229,21 +1266,28 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
12291266
# drop requests arriving after shutdown
12301267
valid_new_requests = []
12311268
for req_item in new_requests:
1232-
if req_item.is_shutdown_request():
1269+
if req_item.is_shutdown_request:
12331270
self.is_shutdown = True
12341271
break
1272+
elif req_item.is_canceled_request:
1273+
self.canceled_req_ids.append(req_item.id)
12351274
else:
12361275
valid_new_requests.append(req_item)
12371276
# Check if the beam width of the requests is equal to the max_beam_width
12381277
for req_item in valid_new_requests:
12391278
assert req_item.request.sampling_config.beam_width == self.max_beam_width, f"Request beam width {req_item.request.sampling_config.beam_width} is not equal to max_beam_width {self.max_beam_width}. This is not supported!"
1240-
new_requests = valid_new_requests
12411279

12421280
if py_request_objects and (self.dist.tp_size > 1
12431281
or self.dist.has_pp) and self.dist.rank > 0:
12441282
for attr_name, req_obj_dict in py_request_objects:
1245-
self._attach_py_objects_to_requests(new_requests, attr_name,
1246-
req_obj_dict)
1283+
self._attach_py_objects_to_requests(valid_new_requests,
1284+
attr_name, req_obj_dict)
1285+
1286+
self.waiting_queue.extend(valid_new_requests)
1287+
1288+
new_requests = _get_from_waiting_queue(
1289+
self.waiting_queue,
1290+
total_max_num_active_requests - total_num_active_requests)
12471291

12481292
if not self.enable_attention_dp:
12491293
self._update_new_active_requests_queue_latency(new_requests)
@@ -1339,7 +1383,7 @@ def _collect_py_objects_from_requests(
13391383
"""
13401384
req_id_to_obj = {}
13411385
for item in requests:
1342-
if item.is_shutdown_request():
1386+
if not item.is_normal_request:
13431387
continue
13441388
obj = getattr(item.request, attribute_name, None)
13451389
if obj is not None:
@@ -1926,41 +1970,28 @@ def _handle_errors(self, error_msg: Optional[str] = None):
19261970
def _terminate_request(self, request: LlmRequest):
19271971
self.resource_manager.free_resources(request)
19281972

1929-
@nvtx_range("_handle_cancelled_requests")
1930-
def _handle_cancelled_requests(self):
1931-
#TODO: properly handle canceled ids in pp case
1932-
if self.dist.has_tp:
1933-
self.canceled_req_ids = self.dist.broadcast(self.canceled_req_ids,
1934-
root=0)
1935-
1973+
@nvtx_range("_handle_canceled_requests")
1974+
def _handle_canceled_requests(self):
19361975
if len(self.canceled_req_ids) == 0:
19371976
return
19381977

1939-
cancelled_responses = {}
1940-
left_requests = []
1941-
# Tracks canceled requests for proper handling in overlap mode during `sampler.update_requests`.
1942-
self.canceled_requests = []
1978+
# cancel request in the waiting queue
1979+
self.waiting_queue = deque(req for req in self.waiting_queue
1980+
if req.id not in self.canceled_req_ids)
1981+
19431982
for request in self.active_requests:
19441983
req_id = request.py_request_id
19451984
if req_id in self.canceled_req_ids:
1946-
self._terminate_request(request)
1985+
# Mark requests as finished, then, we reuse all existing code
1986+
# to clean up the KV cache resources.
19471987
request.finish_by_reason(FinishReason.CANCELLED)
19481988
request.decoding_iter = request.py_decoding_iter
1949-
cancelled_responses[req_id] = request.create_response(
1950-
False, self.dist.rank)
1951-
self.canceled_requests.append(request)
1952-
self.canceled_req_ids.erase(req_id)
1953-
else:
1954-
left_requests.append(request)
1955-
self.active_requests = left_requests
19561989

1957-
# When enable attention dp, each rank does not have full copy of requests
1958-
# so we need to remove the cancel requests not in the local rank
1959-
self.canceled_req_ids.clear()
1960-
1961-
# enqueue the cancelled requests' responses as they are not
1962-
# active_requests and be discarded in the sampler loop.
1963-
self._enqueue_responses(cancelled_responses)
1990+
if self.enable_attention_dp:
1991+
# TODO: revisit the cancel logic of attention dp
1992+
# When enable attention dp, each rank does not have full copy of requests
1993+
# so we need to remove the cancel requests not in the local rank
1994+
self.canceled_req_ids.clear()
19641995

19651996
@nvtx_range("_enqueue_responses")
19661997
def _enqueue_responses(self, responses: Dict[int, LlmResponse]):

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
124124
logits_dim = logits.dim()
125125
if logits_dim == 1:
126126
logits = logits.unsqueeze(0)
127-
assert logits_dim == 2, "logits should be 2D [batch_size, vocab_size]"
127+
assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
128128

129129
# sort the logits of each sample in descending order
130130
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)

tests/unittest/llmapi/apps/_test_openai_misc.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ async def test_request_cancellation(server: RemoteOpenAIServer,
8686
model_name: str):
8787
# clunky test: send an ungodly amount of load in with short timeouts
8888
# then ensure that it still responds quickly afterwards
89-
pytest.skip("https://nvbugs/5310314")
90-
9189
chat_input = [{"role": "user", "content": "Write a long story"}]
9290
client = server.get_async_client(timeout=0.5, max_retries=3)
9391
tasks = []

0 commit comments

Comments (0)