
Commit 1c003f8
rebase to v1.0.0rc6 and works
1 parent 9ccc61d commit 1c003f8

4 files changed: +88 −19 lines changed

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 39 additions & 1 deletion

@@ -18,6 +18,9 @@
 from .sampler import Sampler, TorchSampler
 
 SHUTDOWN_REQUEST_ID = -1
+UPDATE_WEIGHT_REQUEST_ID = -2
+SLEEP_REQUEST_ID = -3
+WAKEUP_REQUEST_ID = -4
 
 
 @dataclasses.dataclass
@@ -28,15 +31,33 @@ class RequestQueueItem:
     child_req_ids: Optional[list] = None
     is_canceled_request: bool = False
     query: Optional[list] = None  # only used in `StarAttention`
+    weight_ipc_handles: Optional[dict] = None
+    sleep_level: Optional[int] = None
+    wakeup_level: Optional[int] = None
 
     @property
     def is_shutdown_request(self):
         return self.id == SHUTDOWN_REQUEST_ID
 
     @property
     def is_normal_request(self):
-        return not (self.is_shutdown_request or self.is_canceled_request)
+        return self.id > 0 and not self.is_canceled_request
 
+    @property
+    def is_update_weight_request(self):
+        return self.id == UPDATE_WEIGHT_REQUEST_ID
+
+    @property
+    def is_sleep_request(self):
+        return self.id == SLEEP_REQUEST_ID
+
+    @property
+    def is_wakeup_request(self):
+        return self.id == WAKEUP_REQUEST_ID
+
+    @property
+    def is_control_request(self):
+        return self.is_update_weight_request or self.is_sleep_request or self.is_wakeup_request
 
 class ExecutorRequestQueue:
     """Handles fetching and processing of new requests from the request queue."""
@@ -66,6 +87,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
         self.new_active_requests_queue_latency_ms = 0
         self.is_shutdown = False
         self.should_exclude_last_generation_logits = False
+        self.control_requests: List[RequestQueueItem] = []
 
     def _get_from_request_queue(
         self,
@@ -226,6 +248,20 @@ def enqueue_cancel_request(self, req_id: int):
         self.request_queue.put(
             RequestQueueItem(req_id, is_canceled_request=True))
 
+    def enqueue_sleep_request(self, req_id: int, sleep_level: int):
+        with self.enqueue_lock:
+            print(f"enqueue_sleep_request: {req_id} {sleep_level}")
+            self.request_queue.put(
+                RequestQueueItem(req_id, sleep_level=sleep_level))
+
+    def enqueue_wakeup_request(self, req_id: int, wakeup_level: int):
+        with self.enqueue_lock:
+            self.request_queue.put(RequestQueueItem(req_id, wakeup_level=wakeup_level))
+
+    def enqueue_update_weight_request(self, req_id: int, weight_ipc_handles: dict):
+        with self.enqueue_lock:
+            self.request_queue.put(RequestQueueItem(req_id, weight_ipc_handles=weight_ipc_handles))
+
     def enqueue_shutdown_request(self):
         with self.enqueue_lock:
             self.request_queue.put(RequestQueueItem(SHUTDOWN_REQUEST_ID))
@@ -431,6 +467,8 @@ def _validate_and_filter_requests(
                 break
             elif req_item.is_canceled_request:
                 self.canceled_req_ids.append(req_item.id)
+            elif req_item.is_control_request:
+                self.control_requests.append(req_item)
             else:
                 valid_new_requests.append(req_item)
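Taken together, this file's changes reserve negative sentinel IDs so control messages can ride the same queue as generation traffic, while `is_normal_request` now admits only positive IDs. A minimal, self-contained sketch of that routing idea (illustrative only; it mirrors names from the diff but is not the real module):

import dataclasses
from typing import Optional

# Reserved sentinel ids, as in the diff: normal requests are strictly positive.
SHUTDOWN_REQUEST_ID = -1
UPDATE_WEIGHT_REQUEST_ID = -2
SLEEP_REQUEST_ID = -3
WAKEUP_REQUEST_ID = -4


@dataclasses.dataclass
class RequestQueueItem:
    id: int
    is_canceled_request: bool = False
    sleep_level: Optional[int] = None

    @property
    def is_control_request(self) -> bool:
        return self.id in (UPDATE_WEIGHT_REQUEST_ID, SLEEP_REQUEST_ID,
                           WAKEUP_REQUEST_ID)

    @property
    def is_normal_request(self) -> bool:
        # Sentinel (negative) ids never reach the scheduler as normal work.
        return self.id > 0 and not self.is_canceled_request


# A sleep control item is filtered into `control_requests`, not the batch.
item = RequestQueueItem(SLEEP_REQUEST_ID, sleep_level=1)
assert item.is_control_request and not item.is_normal_request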

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 45 additions & 15 deletions

@@ -334,6 +334,32 @@ def cancel_request(self, id: int):
         """
         self.executor_request_queue.enqueue_cancel_request(id)
 
+    def enqueue_sleep_request(self, id: int, sleep_level: int):
+        """
+        Enqueue a sleep request with the provided request id and sleep level.
+        Args:
+            id (int): The request id for which to sleep
+            sleep_level (int): The sleep level to apply to the request
+        """
+        self.executor_request_queue.enqueue_sleep_request(id, sleep_level)
+
+    def enqueue_wakeup_request(self, id: int, wakeup_level: int):
+        """
+        Enqueue a wakeup request with the provided request id and wakeup level.
+        Args:
+            id (int): The request id for which to wake up
+        """
+        self.executor_request_queue.enqueue_wakeup_request(id, wakeup_level)
+
+    def enqueue_update_weight_request(self, id: int, weight_ipc_handles: dict):
+        """
+        Enqueue an update-weight request with the provided request id and weight IPC handles.
+        Args:
+            id (int): The request id for which to update weights
+            weight_ipc_handles (dict): The weight IPC handles to update from
+        """
+        self.executor_request_queue.enqueue_update_weight_request(id, weight_ipc_handles)
+
     def shutdown(self):
         """
         Signals the server to shutdown.
@@ -1080,19 +1106,19 @@ def update_weight_from_ipc_handles(self, handles):
 
     def _sleep(self, sleep_request):
         self.is_sleep_request = False
-        self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=sleep_request.id)})
+        self._enqueue_responses([(sleep_request.id, LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=sleep_request.id))])
 
     def _wakeup(self, wakeup_request):
        self.is_wakeup_request = False
-        self._enqueue_responses({wakeup_request.id: LlmResponse(request_id=wakeup_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=wakeup_request.id)})
+        self._enqueue_responses([(wakeup_request.id, LlmResponse(request_id=wakeup_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=wakeup_request.id))])
 
     def _update_weight(self, update_weight_request):
         self.is_update_weight_request = False
 
         try:
             self.update_weight_from_ipc_handles(update_weight_request.weight_ipc_handles)
             update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=update_weight_request.id)
-            self._enqueue_responses({update_weight_request.id: update_weight_response})
+            self._enqueue_responses([(update_weight_request.id, update_weight_response)])
         except Exception as e:
             print(
                 f"Error in update_weights_from_ipc_handles: {e}"
@@ -1101,6 +1127,20 @@ def _update_weight(self, update_weight_request):
             #update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=False), is_final=True), client_id=update_weight_request.id)
             #self._enqueue_responses({update_weight_request.id: update_weight_response})
 
+    def _handle_control_request(self):
+        if len(self.executor_request_queue.control_requests) > 0:
+            assert len(self.executor_request_queue.control_requests) == 1, f"control request should be the only request in the list, but got {len(self.executor_request_queue.control_requests)}"
+            control_request = self.executor_request_queue.control_requests.pop()
+            if control_request.is_update_weight_request:
+                self._update_weight(control_request)
+            elif control_request.is_sleep_request:
+                self._sleep(control_request)
+            elif control_request.is_wakeup_request:
+                self._wakeup(control_request)
+            else:
+                assert False, "Invalid control request"
+
+
     def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
         # ensure the context is created, otherwise, some MPI calls will fail.
@@ -1122,20 +1162,10 @@ def _executor_loop_overlap(self):
                 iter_start_time = time.time()
 
                 scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
+                self._handle_control_request()
+
                 if scheduled_batch is None:
                     break
-                if self.is_control_request:
-                    self.is_control_request = False
-                    assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
-                    if (new_requests[0].is_update_weight_request()):
-                        self._update_weight(new_requests[0])
-                    elif (new_requests[0].is_sleep_request()):
-                        self._sleep(new_requests[0])
-                    elif (new_requests[0].is_wakeup_request()):
-                        self._wakeup(new_requests[0])
-                    else:
-                        assert False, "Invalid control request"
-                    continue
 
                 self._pause_requests(scheduled_batch.paused_requests)
 
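Two things change here: control handling moves into a dedicated `_handle_control_request` hook that runs once per scheduler iteration, right after batch scheduling, and `_enqueue_responses` now receives a list of `(request_id, response)` tuples instead of a dict. A hedged sketch of the acknowledgement shape under that new convention (`Ack` is a stand-in for `LlmResponse`, not the real class):

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Ack:
    """Stand-in for LlmResponse: a final, empty success result."""
    request_id: int
    is_final: bool = True


def control_acks(req_ids: List[int]) -> List[Tuple[int, Ack]]:
    # _sleep/_wakeup/_update_weight now acknowledge with (id, response)
    # pairs, matching the list-of-tuples signature used in the diff.
    return [(rid, Ack(request_id=rid)) for rid in req_ids]


print(control_acks([-3]))  # e.g. acknowledging a sleep control request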

tensorrt_llm/_torch/utils.py

Lines changed: 1 addition & 0 deletions

@@ -284,6 +284,7 @@ def nvml_context() -> Generator[None, None, None]:
 
 
 def device_id_to_physical_device_id(device_id: int) -> int:
     """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES."""
+    import os
     if "CUDA_VISIBLE_DEVICES" in os.environ:
         device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
         try:
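For context, this helper maps a logical CUDA ordinal to the physical GPU index by indexing into CUDA_VISIBLE_DEVICES; the added local `import os` makes the function usable on its own. A runnable sketch, with the part past `try:` (not visible in the hunk) filled in as an assumption:

import os


def device_id_to_physical_device_id(device_id: int) -> int:
    """Convert a logical device ID to a physical one via CUDA_VISIBLE_DEVICES."""
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        try:
            return int(device_ids[device_id])
        except (IndexError, ValueError) as err:
            # Assumed error path: the hunk ends before the except clause.
            raise RuntimeError(
                f"device_id {device_id} is not usable with "
                f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}") from err
    # Without the env var, logical and physical ids coincide.
    return device_id


# Example: with CUDA_VISIBLE_DEVICES="2,3", logical device 1 is physical GPU 3.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
assert device_id_to_physical_device_id(1) == 3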

tensorrt_llm/executor/worker.py

Lines changed: 3 additions & 3 deletions

@@ -460,13 +460,13 @@ def _deduce_max_tokens(request: GenerationRequest,
                     f"is larger than max_seq_len {executor_config.max_seq_len}")
             return default_max_tokens
         if request.is_weight_update_request():
-            req_id = self.engine.enqueue_request(request, weight_ipc_handles=request.weight_ipc_handles)
+            req_id = self.engine.enqueue_update_weight_request(request.id, weight_ipc_handles=request.weight_ipc_handles)
             return req_id
         elif request.is_sleep_request():
-            req_id = self.engine.enqueue_request(request, sleep_level=request.sleep_level)
+            req_id = self.engine.enqueue_sleep_request(request.id, sleep_level=request.sleep_level)
             return req_id
         elif request.is_wakeup_request():
-            req_id = self.engine.enqueue_request(request, wakeup_level=request.wakeup_level)
+            req_id = self.engine.enqueue_wakeup_request(request.id, wakeup_level=request.wakeup_level)
             return req_id
         try:
             executor_request = tllm.Request(
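The net effect in the worker is a clean dispatch: each special request type now calls its dedicated engine entry point with `request.id`, rather than funneling everything through an overloaded `enqueue_request`. A sketch of that routing (the `engine` and `request` objects are stand-ins mirroring the diff):

def route_special_request(engine, request):
    # Mirrors the three rewritten branches in worker.py; returns None when
    # the request should fall through to normal tllm.Request construction.
    if request.is_weight_update_request():
        return engine.enqueue_update_weight_request(
            request.id, weight_ipc_handles=request.weight_ipc_handles)
    if request.is_sleep_request():
        return engine.enqueue_sleep_request(
            request.id, sleep_level=request.sleep_level)
    if request.is_wakeup_request():
        return engine.enqueue_wakeup_request(
            request.id, wakeup_level=request.wakeup_level)
    return None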
