Commit dfb3dd1

update weights works, verl integration tp/dp works.
1 parent cfcb97a commit dfb3dd1

File tree: 10 files changed, +320 -17 lines

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 3 additions & 2 deletions
@@ -246,8 +246,9 @@ def __getattr__(self, item):
         return getattr(result, item)

     def deserialize(self):
-        self._result = tensorrt_llm.bindings.executor.deserialize_result(
-            self._result)
+        if self._result is not None:
+            self._result = tensorrt_llm.bindings.executor.deserialize_result(
+                self._result)


 @dataclass
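The added `None` guard matters for the control-request responses introduced in py_executor.py below, which are constructed as `LlmResult(result=None, py_result=None, is_final=True)` and therefore carry no serialized payload to deserialize.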

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 0 additions & 1 deletion
@@ -1087,7 +1087,6 @@ def init_meta_tensor(t: torch.Tensor):
             weights = load_weights(model.llm_checkpoint_dir)
         else:
             weights = load_weights(checkpoint_dir)
-
         model.load_weights(weights)

         if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 136 additions & 6 deletions
@@ -17,6 +17,7 @@

 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
+from tensorrt_llm._torch.utils import get_device_uuid
 from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
                                  is_trace_enabled, nvtx_range, trace_func)
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
@@ -33,7 +34,7 @@
 from ..speculative.drafter import Drafter
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
-                          LlmResponse, executor_request_to_llm_request)
+                          LlmResponse, LlmResult, executor_request_to_llm_request)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler
 from .scheduler import RequestScheduler, ScheduledRequests
@@ -51,6 +52,9 @@
 PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"

 SHUTDOWN_REQUEST_ID = -1
+UPDATE_WEIGHT_REQUEST_ID = -2
+SLEEP_REQUEST_ID = -3
+WAKEUP_REQUEST_ID = -4


 @dataclasses.dataclass
@@ -59,15 +63,25 @@ class RequestQueueItem:
     request: Optional[ExecutorRequest] = None
     is_canceled_request: bool = False
     query: Optional[list] = None  # only used in `StarAttention`
+    weight_ipc_handles: Optional[dict] = None
+    sleep_level: Optional[int] = None
+    wakeup_level: Optional[int] = None

     @property
     def is_shutdown_request(self):
         return self.id == SHUTDOWN_REQUEST_ID

     @property
     def is_normal_request(self):
-        return not (self.is_shutdown_request or self.is_canceled_request)
+        return self.id > 0 and not self.is_canceled_request
+    def is_update_weight_request(self):
+        return self.id == UPDATE_WEIGHT_REQUEST_ID

+    def is_sleep_request(self):
+        return self.id == SLEEP_REQUEST_ID
+
+    def is_wakeup_request(self):
+        return self.id == WAKEUP_REQUEST_ID

 def _get_from_request_queue(
         request_queue,
@@ -244,6 +258,7 @@ def __init__(self,
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
+        self.request_accumulator: List[RequestQueueItem] = []

         # response used data
         self.response_lock = threading.Lock()
@@ -287,6 +302,8 @@
             self.draft_model_engine.warmup(self.resource_manager)

         self.is_shutdown = False
+        self.is_control_request = False
+        self.control_request_id = 0

         self.stats_lock = threading.Lock()
         self.stats = []
@@ -465,7 +482,10 @@ def wait_shutdown(self):

     def enqueue_request(self,
                         request: ExecutorRequest,
-                        query: Optional[List] = None):
+                        query: Optional[List] = None,
+                        weight_ipc_handles: Optional[dict] = None,
+                        sleep_level: Optional[int] = None,
+                        wakeup_level: Optional[int] = None):
         """
         Enqueue a new request, query is only used in `StarAttention`.
         """
@@ -476,10 +496,17 @@
             if self.enable_iter_perf_stats:
                 self.start_times[req_id] = time.time()

-            if query is not None:
+            if weight_ipc_handles is not None:
+                self.request_queue.put(RequestQueueItem(UPDATE_WEIGHT_REQUEST_ID, None, False, None, weight_ipc_handles))
+            elif sleep_level is not None:
+                self.request_queue.put(RequestQueueItem(SLEEP_REQUEST_ID, None, False, None, None, sleep_level))
+            elif wakeup_level is not None:
+                self.request_queue.put(RequestQueueItem(WAKEUP_REQUEST_ID, None, False, None, None, None, wakeup_level))
+            elif query is not None:
                 self.request_queue.put(RequestQueueItem(req_id, request, query))
             else:
                 self.request_queue.put(RequestQueueItem(req_id, request))
+            #self.request_queue.put(RequestQueueItem(req_id, request, False, query, weight_ipc_handles, sleep_level, wakeup_level))
             self.next_req_id += 1
         finally:
             self.enqueue_lock.release()
@@ -756,6 +783,18 @@ def _executor_loop_pp(self):
                 new_requests = self._fetch_new_requests()
                 if self.should_stop_processing:
                     break
+                if self.is_control_request:
+                    self.is_control_request = False
+                    assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
+                    if (new_requests[0].is_update_weight_request()):
+                        self._update_weight(new_requests[0])
+                    elif (new_requests[0].is_sleep_request()):
+                        self._sleep(new_requests[0])
+                    elif (new_requests[0].is_wakeup_request()):
+                        self._wakeup(new_requests[0])
+                    else:
+                        assert False, "Invalid control request"
+                    continue

                 if self.enable_iter_perf_stats:
                     iter_stats = self._get_init_iter_stats(
@@ -907,6 +946,18 @@ def _executor_loop(self):
                 new_requests = self._fetch_new_requests()
                 if self.should_stop_processing:
                     break
+                if self.is_control_request:
+                    self.is_control_request = False
+                    assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
+                    if (new_requests[0].is_update_weight_request()):
+                        self._update_weight(new_requests[0])
+                    elif (new_requests[0].is_sleep_request()):
+                        self._sleep(new_requests[0])
+                    elif (new_requests[0].is_wakeup_request()):
+                        self._wakeup(new_requests[0])
+                    else:
+                        assert False, "Invalid control request"
+                    continue

                 if self.kv_cache_transceiver:
                     self._check_disagg_gen_transfer_status()
@@ -1033,6 +1084,50 @@ def _prepare_draft_requests(self):
             logger.error(f"Encountered an error in decode: {error_msg}")
             self._handle_errors(error_msg)

+    def _sleep(self, sleep_request):
+        self.is_sleep_request = False
+        self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=None, is_final=True), client_id=sleep_request.id)})
+
+    def _wakeup(self, wakeup_request):
+        self.is_wakeup_request = False
+        self._enqueue_responses({wakeup_request.id: LlmResponse(request_id=wakeup_request.id, result=LlmResult(result=None, py_result=None, is_final=True), client_id=wakeup_request.id)})
+
+    def _update_weight(self, update_weight_request):
+        self.is_update_weight_request = False
+
+        try:
+            # Get handles for this device
+            device_uuid = get_device_uuid(self.device_id)
+            handles = update_weight_request.weight_ipc_handles[device_uuid]
+            weights = {}
+
+            # Process each handle to get the tensor
+            i = 0
+            for name, handle in handles:
+                func, args = handle
+                list_args = list(args)
+                # Update device ID to match the current device
+                list_args[6] = self.device_id
+                tensor = func(*list_args)
+                if i % 2 == 0:
+                    weights[name] = tensor
+                else:
+                    weights[name] = tensor  # + 1.0
+                i += 1
+
+            # Load weights into the model
+            self.model_engine.model.load_weights(weights)
+
+            torch.cuda.synchronize()
+            update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=None, is_final=True), client_id=update_weight_request.id)
+            self._enqueue_responses({update_weight_request.id: update_weight_response})
+        except Exception as e:
+            print(
+                f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}"
+            )
+            update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=None, is_final=True), client_id=update_weight_request.id)
+            self._enqueue_responses({update_weight_request.id: update_weight_response})
+
     def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
         if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
@@ -1052,6 +1147,18 @@ def _executor_loop_overlap(self):
                 new_requests = self._fetch_new_requests()
                 if self.should_stop_processing:
                     break
+                if self.is_control_request:
+                    self.is_control_request = False
+                    assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
+                    if (new_requests[0].is_update_weight_request()):
+                        self._update_weight(new_requests[0])
+                    elif (new_requests[0].is_sleep_request()):
+                        self._sleep(new_requests[0])
+                    elif (new_requests[0].is_wakeup_request()):
+                        self._wakeup(new_requests[0])
+                    else:
+                        assert False, "Invalid control request"
+                    continue

                 if self.kv_cache_transceiver:
                     self._check_disagg_gen_transfer_status()
@@ -1263,20 +1370,43 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
             new_requests, py_request_objects = self._broadcast_new_requests(
                 new_requests, py_request_objects)

+        self.request_accumulator.extend(new_requests)
+
         # drop requests arriving after shutdown
         valid_new_requests = []
-        for req_item in new_requests:
+        find_control_request = False
+        for i, req_item in enumerate(self.request_accumulator):
             if req_item.is_shutdown_request:
                 self.is_shutdown = True
+                find_control_request = True
+                break
+            if req_item.is_update_weight_request() or req_item.is_sleep_request() or req_item.is_wakeup_request():
+                find_control_request = True
+                self.control_request_id = req_item.id
                 break
             elif req_item.is_canceled_request:
                 self.canceled_req_ids.append(req_item.id)
+
+        if (find_control_request):
+            if (i==0):
+                if not self.is_shutdown:
+                    valid_new_requests = self.request_accumulator[:1]
+                    self.is_control_request = True
+                self.request_accumulator = self.request_accumulator[1:]
+                return valid_new_requests
             else:
-                valid_new_requests.append(req_item)
+                valid_new_requests = self.request_accumulator[:i]
+                self.request_accumulator = self.request_accumulator[i:]
+        else:
+            valid_new_requests = self.request_accumulator
+            self.request_accumulator = []
+
         # Check if the beam width of the requests is equal to the max_beam_width
         for req_item in valid_new_requests:
             assert req_item.request.sampling_config.beam_width == self.max_beam_width, f"Request beam width {req_item.request.sampling_config.beam_width} is not equal to max_beam_width {self.max_beam_width}. This is not supported!"

+        new_requests = valid_new_requests
+
         if py_request_objects and (self.dist.tp_size > 1
                                    or self.dist.has_pp) and self.dist.rank > 0:
             for attr_name, req_obj_dict in py_request_objects:
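For reference, `_update_weight` above expects `weight_ipc_handles` to map each GPU's UUID to a list of `(name, (rebuild_func, args))` entries and rewrites `args[6]` (the device index) before rebuilding each tensor. Below is a hypothetical producer-side sketch, assuming the handles come from `torch.multiprocessing.reductions.reduce_tensor`, which for CUDA tensors yields exactly such a `(func, args)` pair with the device index at position 6; `build_weight_ipc_handles` is an illustrative helper, not part of this commit.

# Hypothetical trainer-side helper (illustrative only): package CUDA weight tensors as
# IPC handles keyed by the producing GPU's UUID. The tensors must stay alive in this
# process while the executor rebuilds them from the handles.
from torch.multiprocessing.reductions import reduce_tensor

from tensorrt_llm._torch.utils import get_device_uuid


def build_weight_ipc_handles(named_weights, device_idx):
    # named_weights: iterable of (param_name, cuda_tensor) pairs resident on device_idx.
    handles = [(name, reduce_tensor(tensor)) for name, tensor in named_weights]
    return {get_device_uuid(device_idx): handles}

On the receiving side, `_update_weight` looks up the current device's UUID via `get_device_uuid(self.device_id)`, patches `list_args[6]`, and calls `func(*list_args)` to materialize each tensor before `model.load_weights(weights)`.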

tensorrt_llm/_torch/utils.py

Lines changed: 60 additions & 1 deletion
@@ -3,9 +3,10 @@
 import threading
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List
+from typing import Dict, List, Generator

 import torch
+import pynvml

 from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
 from tensorrt_llm.math_utils import ceil_div, pad_up
@@ -259,3 +260,61 @@ def set_piecewise_cuda_graph_flag(enable: bool):
 def get_piecewise_cuda_graph_flag() -> bool:
     global _enable_piecewise_cuda_graph
     return _enable_piecewise_cuda_graph
+
+
+@contextlib.contextmanager
+def nvml_context() -> Generator[None, None, None]:
+    """Context manager for NVML initialization and shutdown.
+
+    Raises:
+        RuntimeError: If NVML initialization fails
+    """
+    try:
+        pynvml.nvmlInit()
+        yield
+    except pynvml.NVMLError as e:
+        raise RuntimeError(f"Failed to initialize NVML: {e}")
+    finally:
+        try:
+            pynvml.nvmlShutdown()
+        except:
+            pass
+
+def device_id_to_physical_device_id(device_id: int) -> int:
+    """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES."""
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
+        try:
+            physical_device_id = int(device_ids[device_id])
+            return physical_device_id
+        except ValueError:
+            raise RuntimeError(
+                f"Failed to convert logical device ID {device_id} to physical device ID. Available devices are: {device_ids}."
+            )
+    else:
+        return device_id
+
+def get_device_uuid(device_idx: int) -> str:
+    """Get the UUID of a CUDA device using NVML."""
+    # Convert logical device index to physical device index
+    global_device_idx = device_id_to_physical_device_id(device_idx)
+
+    # Get the device handle and UUID
+    with nvml_context():
+        try:
+            handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
+            uuid = pynvml.nvmlDeviceGetUUID(handle)
+            # Ensure the UUID is returned as a string, not bytes
+            if isinstance(uuid, bytes):
+                return uuid.decode("utf-8")
+            elif isinstance(uuid, str):
+                return uuid
+            else:
+                raise RuntimeError(
+                    f"Unexpected UUID type: {type(uuid)} for device {device_idx} (global index: {global_device_idx})"
+                )
+        except pynvml.NVMLError as e:
+            raise RuntimeError(
+                f"Failed to get device UUID for device {device_idx} (global index: {global_device_idx}): {e}"
+            )
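A small usage sketch (assumed, not part of the diff): resolving each visible logical device index to the NVML UUID that keys the per-device handle dictionaries above.

import torch

from tensorrt_llm._torch.utils import get_device_uuid

# Logical index -> UUID (e.g. "GPU-...") for every device visible to this process;
# CUDA_VISIBLE_DEVICES is honored via device_id_to_physical_device_id.
uuid_by_device = {idx: get_device_uuid(idx) for idx in range(torch.cuda.device_count())}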

tensorrt_llm/executor/executor.py

Lines changed: 19 additions & 0 deletions
@@ -201,6 +201,25 @@ def generate(

         return futures

+    def async_update_weights_from_ipc_handles(self, handles: dict):
+        update_weights_request = GenerationRequest([], SamplingParams(end_id=0))
+        update_weights_request.set_weight_ipc_handles(handles)
+        result = self.submit(update_weights_request)
+        return result
+
+    def async_sleep(self, level: int = 1):
+        sleep_request = GenerationRequest([], SamplingParams(end_id=0))
+        sleep_request.set_sleep_level(level)
+        result = self.submit(sleep_request)
+        return result
+
+    def async_wakeup(self):
+        sleep_request = GenerationRequest([], SamplingParams(end_id=0))
+        sleep_request.set_wakeup_level(1)
+        result = self.submit(sleep_request)
+        return result
+
+
     def _get_next_client_id(self):
         # (self._last_client_id + 1) % UINT64_MAX
         self._last_client_id = (self._last_client_id + 1) & ((1 << 64) - 1)
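A hypothetical caller-side sequence (e.g. a verl-style trainer loop) using the new control APIs; `executor` is assumed to be a GenerationExecutor and `handles` the per-device-UUID dict sketched earlier. Each method returns whatever `submit()` returns for the synthetic request, so the caller can wait on those results before continuing.

# Quiesce generation, push refreshed weights over CUDA IPC, then resume serving.
sleep_result = executor.async_sleep(level=1)
update_result = executor.async_update_weights_from_ipc_handles(handles)
wakeup_result = executor.async_wakeup()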
