
Commit 9ccc61d

rebase to v1.0.0rc6
- add success into GenerationResult
- enable_block_reuse=False
- clean device uuid code
- enable partial load
- align interfaces with ray branch
- stream update weights
- verified ray + trtllm works
- resolve ray conflict
- raise exception in update_weights
- add gate/up bundled update (works)
1 parent a16ba64 commit 9ccc61d

13 files changed: +944, -21 lines


tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 13 additions & 7 deletions
@@ -804,6 +804,9 @@ def load_single_module(name, module):
                 for new_name in params_map[names[-1]]:
                     fw = filter_weights('.'.join(names[:-1] + [new_name]),
                                         weights)
+                    # tmp fixes to enable partial updates in old path
+                    if not fw:
+                        continue
                     if new_name in ['k_proj', 'v_proj']:
                         num_kv_heads_list = [num_kv_heads
                                              ] * len(fw) if isinstance(
@@ -820,15 +823,18 @@ def load_single_module(name, module):
                         }

                     module_weights.append(fw)
-                module.load_weights(weights=module_weights)
+                if module_weights:
+                    module.load_weights(weights=module_weights)
+
             else:
                 module_weights = filter_weights(name, weights)
-                if hasattr(module, 'load_weights'):
-                    module.load_weights(weights=[module_weights])
-                else:
-                    for n, p in module._parameters.items():
-                        if p is not None:
-                            p.data.copy_(module_weights[n][:])
+                if module_weights:
+                    if hasattr(module, 'load_weights'):
+                        module.load_weights(weights=[module_weights])
+                    else:
+                        for n, p in module._parameters.items():
+                            if p is not None:
+                                p.data.copy_(module_weights[n][:])

    if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
                      False) in ["True", "true", "1", "yes", "y"]:
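
Note: the two new guards are what make partial updates work on this legacy path. When only a subset of weights is streamed in (for example a bundled gate/up update), filter_weights comes back empty for every module the update does not touch, and the loader now skips those modules instead of calling load_weights with nothing. A minimal sketch of the behaviour being assumed; the helper below is an illustrative stand-in, not the real filter_weights:

# Illustrative only: mimics the prefix-filtering behaviour the guards rely on.
def filter_weights(prefix: str, weights: dict) -> dict:
    return {k[len(prefix) + 1:]: v
            for k, v in weights.items()
            if k.startswith(prefix + '.')}

partial_update = {"model.layers.0.mlp.gate_up_proj.weight": "tensor placeholder"}

fw = filter_weights("model.layers.1.self_attn.q_proj", partial_update)
if not fw:   # the new guard: this module is not part of the partial update
    pass     # skip it instead of calling module.load_weights with an empty list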

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 6 additions & 3 deletions
@@ -164,14 +164,16 @@ def __init__(self,
                  return_log_probs: bool = False,
                  return_context_logits: bool = False,
                  return_generation_logits: bool = False,
-                 exclude_last_generation_logits: bool = False):
+                 exclude_last_generation_logits: bool = False,
+                 success: bool = False):
         self._streaming = streaming
         self._context_logits = LogitsStorage(
             prompt_len, use_device_memory) if return_context_logits else None
         self._generation_logits = LogitsStorage(
             max_new_tokens, use_device_memory, exclude_last_generation_logits
         ) if return_generation_logits else None
         self._log_probs = LogProbStorage() if return_log_probs else None
+        self._success = success

     def append_context_logits(self, context_logits: torch.Tensor):
         if self._context_logits:
@@ -247,8 +249,9 @@ def __getattr__(self, item):
         return getattr(result, item)

     def deserialize(self):
-        self._result = tensorrt_llm.bindings.executor.deserialize_result(
-            self._result)
+        if self._result is not None:
+            self._result = tensorrt_llm.bindings.executor.deserialize_result(
+                self._result)


 @dataclass
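
Note: the new success flag and the None check in deserialize() work together. Control operations (weight update, sleep, wakeup) produce responses that carry no serialized C++ executor result, only a Python-side result whose success flag reports the outcome. A short sketch mirroring the constructor calls that py_executor.py (below) uses for such responses; argument meanings are taken from that code rather than a documented API:

# Control-request responses have no executor result to deserialize; only the
# PyResult's success flag matters. Mirrors the calls added in py_executor.py.
py_result = PyResult(0, 0, success=True)
control_result = LlmResult(result=None, py_result=py_result, is_final=True)

control_result.deserialize()   # now a no-op instead of failing on a None result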

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 120 additions & 2 deletions
@@ -15,6 +15,7 @@

 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
+from tensorrt_llm._torch.utils import get_device_uuid
 from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
                                  is_trace_enabled, nvtx_range, trace_func)
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
@@ -35,7 +36,7 @@
 from .guided_decoder import GuidedDecoder
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
-                          LlmResponse)
+                          LlmResponse, LlmResult, executor_request_to_llm_request, PyResult)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors
 from .scheduler import RequestScheduler, ScheduledRequests
@@ -184,6 +185,7 @@ def __init__(self,
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
+        self.request_accumulator: List[RequestQueueItem] = []

         # response used data
         self.response_lock = threading.Lock()
@@ -235,6 +237,8 @@ def __init__(self,
         )
         self.executor_request_queue.set_exclude_last_generation_logits(
             self.disable_overlap_scheduler, self.sampler)
+        self.is_control_request = False
+        self.control_request_id = 0

         self.stats_lock = threading.Lock()
         self.stats = []
@@ -383,12 +387,29 @@ def wait_shutdown(self):

     def enqueue_request(self,
                         request: ExecutorRequest,
-                        query: Optional[List] = None) -> int:
+                        query: Optional[List] = None,
+                        weight_ipc_handles: Optional[dict] = None,
+                        sleep_level: Optional[int] = None,
+                        wakeup_level: Optional[int] = None) -> int:
         """
         Enqueue a new request, query is only used in `StarAttention`.
         """
         req_id = self.executor_request_queue.enqueue_request(request, query)

+        ## if weight_ipc_handles is not None:
+        ##     self.request_queue.put(RequestQueueItem(UPDATE_WEIGHT_REQUEST_ID, None, False, None, weight_ipc_handles))
+        ## elif sleep_level is not None:
+        ##     self.request_queue.put(RequestQueueItem(SLEEP_REQUEST_ID, None, False, None, None, sleep_level))
+        ## elif wakeup_level is not None:
+        ##     self.request_queue.put(RequestQueueItem(WAKEUP_REQUEST_ID, None, False, None, None, None, wakeup_level))
+        ## elif query is not None:
+        ##     self.request_queue.put(RequestQueueItem(req_id, request, query))
+        ## else:
+        ##     self.request_queue.put(RequestQueueItem(req_id, request))
+        ## #self.request_queue.put(RequestQueueItem(req_id, request, False, query, weight_ipc_handles, sleep_level, wakeup_level))
+        ## self.next_req_id += 1
+        ## finally:
+        ##     self.enqueue_lock.release()
         return req_id

     def set_gather_responses(self, gather_all_responses):
@@ -666,6 +687,18 @@ def _executor_loop_pp(self):
                 new_requests = self._fetch_and_activate_new_requests()
                 if self.should_stop_processing:
                     break
+                if self.is_control_request:
+                    self.is_control_request = False
+                    assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
+                    if (new_requests[0].is_update_weight_request()):
+                        self._update_weight(new_requests[0])
+                    elif (new_requests[0].is_sleep_request()):
+                        self._sleep(new_requests[0])
+                    elif (new_requests[0].is_wakeup_request()):
+                        self._wakeup(new_requests[0])
+                    else:
+                        assert False, "Invalid control request"
+                    continue

                 if self.kv_cache_transceiver:
                     self._check_disagg_gen_transfer_status()
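
Note: the loops above call is_update_weight_request(), is_sleep_request() and is_wakeup_request() on the fetched queue item, and the commented-out code in enqueue_request suggests these are keyed off sentinel request IDs (UPDATE_WEIGHT_REQUEST_ID, SLEEP_REQUEST_ID, WAKEUP_REQUEST_ID). Those definitions are not part of this diff; the sketch below is a hypothetical illustration of what the predicates might look like, with the sentinel values and dataclass layout being assumptions:

# Hypothetical sketch of the RequestQueueItem predicates the executor loops rely on.
# Sentinel IDs and field layout are assumptions inferred from the commented-out
# code in enqueue_request; they are not defined in this commit.
from dataclasses import dataclass
from typing import Optional

UPDATE_WEIGHT_REQUEST_ID = -100
SLEEP_REQUEST_ID = -101
WAKEUP_REQUEST_ID = -102

@dataclass
class RequestQueueItem:
    id: int
    request: Optional[object] = None
    weight_ipc_handles: Optional[dict] = None
    sleep_level: Optional[int] = None
    wakeup_level: Optional[int] = None

    def is_update_weight_request(self) -> bool:
        return self.id == UPDATE_WEIGHT_REQUEST_ID

    def is_sleep_request(self) -> bool:
        return self.id == SLEEP_REQUEST_ID

    def is_wakeup_request(self) -> bool:
        return self.id == WAKEUP_REQUEST_ID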
@@ -914,6 +947,18 @@ def _executor_loop(self):
                 scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
                 if scheduled_batch is None:
                     break
+                if self.is_control_request:
+                    self.is_control_request = False
+                    assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
+                    if (new_requests[0].is_update_weight_request()):
+                        self._update_weight(new_requests[0])
+                    elif (new_requests[0].is_sleep_request()):
+                        self._sleep(new_requests[0])
+                    elif (new_requests[0].is_wakeup_request()):
+                        self._wakeup(new_requests[0])
+                    else:
+                        assert False, "Invalid control request"
+                    continue

                 self._pause_requests(scheduled_batch.paused_requests)

@@ -995,6 +1040,67 @@ def _prepare_draft_requests(self):
                 logger.error(f"Encountered an error in decode: {error_msg}")
                 self._handle_errors(error_msg)

+    def update_weights(self, weights):
+        # Load weights into the model
+        self.model_engine.model.load_weights(weights)
+        torch.cuda.synchronize()
+
+        # TODO: reset prefix cache
+
+    def update_weight_from_ipc_handles(self, handles):
+        """
+        Update model weights from IPC handles.
+
+        Args:
+            ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles.
+                {device_uuid: all_handles}
+        """
+        from tensorrt_llm._torch.utils import get_device_uuid
+        device_uuid = get_device_uuid(self.device_id)
+
+        if device_uuid not in handles:
+            raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")
+
+        try:
+            weights = {}
+            all_handles = handles[device_uuid]
+
+            for param_name, tensor_handle in all_handles:
+                func, args = tensor_handle
+                list_args = list(args)
+                list_args[6] = self.device_id  # Set target device
+                tensor = func(*list_args)
+                weights[param_name] = tensor
+
+            self.update_weights(weights)
+
+        except Exception as e:
+            logger.error(f"failed to update weights from ipc handles: {e}")
+            raise e
+
+    def _sleep(self, sleep_request):
+        self.is_sleep_request = False
+        self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=sleep_request.id)})
+
+    def _wakeup(self, wakeup_request):
+        self.is_wakeup_request = False
+        self._enqueue_responses({wakeup_request.id: LlmResponse(request_id=wakeup_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=wakeup_request.id)})
+
+    def _update_weight(self, update_weight_request):
+        self.is_update_weight_request = False
+
+        try:
+            self.update_weight_from_ipc_handles(update_weight_request.weight_ipc_handles)
+            update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=update_weight_request.id)
+            self._enqueue_responses({update_weight_request.id: update_weight_response})
+        except Exception as e:
+            print(
+                f"Error in update_weights_from_ipc_handles: {e}"
+            )
+            raise e
+            #update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=False), is_final=True), client_id=update_weight_request.id)
+            #self._enqueue_responses({update_weight_request.id: update_weight_response})
+
     def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
         # ensure the context is created, otherwise, some MPI calls will fail.
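
Note: update_weight_from_ipc_handles expects, per device UUID, a list of (param_name, (rebuild_func, args)) pairs and overwrites args[6] with the local device index before rebuilding each tensor. That layout matches what torch.multiprocessing.reductions.reduce_tensor produces for CUDA tensors (index 6 of the rebuild args is the device). The producer side is not part of this commit; the sketch below shows one way compatible handles could be built, with the packaging inferred from the consumer code rather than defined here:

# Sketch of a producer for the handles dict consumed by update_weight_from_ipc_handles.
# The {device_uuid: [(name, handle), ...]} packaging is an assumption inferred from
# the consumer code above.
import torch
from torch.multiprocessing.reductions import reduce_tensor

from tensorrt_llm._torch.utils import get_device_uuid

def build_ipc_handles(named_params, device_id: int = 0) -> dict:
    all_handles = []
    for name, param in named_params:
        # reduce_tensor returns (rebuild_func, args) for a CUDA tensor
        handle = reduce_tensor(param.detach())
        all_handles.append((name, handle))
    return {get_device_uuid(device_id): all_handles}

# e.g. handles = build_ipc_handles(trainer_model.named_parameters())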
@@ -1018,6 +1124,18 @@ def _executor_loop_overlap(self):
             scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
             if scheduled_batch is None:
                 break
+            if self.is_control_request:
+                self.is_control_request = False
+                assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
+                if (new_requests[0].is_update_weight_request()):
+                    self._update_weight(new_requests[0])
+                elif (new_requests[0].is_sleep_request()):
+                    self._sleep(new_requests[0])
+                elif (new_requests[0].is_wakeup_request()):
+                    self._wakeup(new_requests[0])
+                else:
+                    assert False, "Invalid control request"
+                continue

             self._pause_requests(scheduled_batch.paused_requests)

tensorrt_llm/_torch/utils.py

Lines changed: 72 additions & 1 deletion
@@ -2,9 +2,10 @@
 import threading
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List
+from typing import Dict, List, Generator

 import torch
+import pynvml

 from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
 from tensorrt_llm.math_utils import ceil_div, pad_up
@@ -261,3 +262,73 @@ def set_piecewise_cuda_graph_flag(enable: bool):
 def get_piecewise_cuda_graph_flag() -> bool:
     global _enable_piecewise_cuda_graph
     return _enable_piecewise_cuda_graph
+
+
+@contextlib.contextmanager
+def nvml_context() -> Generator[None, None, None]:
+    """Context manager for NVML initialization and shutdown.
+
+    Raises:
+        RuntimeError: If NVML initialization fails
+    """
+    try:
+        pynvml.nvmlInit()
+        yield
+    except pynvml.NVMLError as e:
+        raise RuntimeError(f"Failed to initialize NVML: {e}")
+    finally:
+        try:
+            pynvml.nvmlShutdown()
+        except:
+            pass
+
+def device_id_to_physical_device_id(device_id: int) -> int:
+    """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES."""
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
+        try:
+            physical_device_id = int(device_ids[device_id])
+            return physical_device_id
+        except ValueError:
+            raise RuntimeError(
+                f"Failed to convert logical device ID {device_id} to physical device ID. Available devices are: {device_ids}."
+            )
+    else:
+        return device_id
+
+def get_device_uuid(device_idx: int) -> str:
+    """Get the UUID of a CUDA device using NVML."""
+    # Convert logical device index to physical device index
+    global_device_idx = device_id_to_physical_device_id(device_idx)
+
+    # Get the device handle and UUID
+    with nvml_context():
+        try:
+            handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
+            uuid = pynvml.nvmlDeviceGetUUID(handle)
+            # Ensure the UUID is returned as a string, not bytes
+            if isinstance(uuid, bytes):
+                return uuid.decode("utf-8")
+            elif isinstance(uuid, str):
+                return uuid
+            else:
+                raise RuntimeError(
+                    f"Unexpected UUID type: {type(uuid)} for device {device_idx} (global index: {global_device_idx})"
+                )
+        except pynvml.NVMLError as e:
+            raise RuntimeError(
+                f"Failed to get device UUID for device {device_idx} (global index: {global_device_idx}): {e}"
+            )

+def get_free_memory_bytes(device_idx: int) -> float:
+    """Get the free memory of a CUDA device in bytes using NVML."""
+    global_device_idx = device_id_to_physical_device_id(device_idx)
+    with nvml_context():
+        try:
+            handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
+            return pynvml.nvmlDeviceGetMemoryInfo(handle).free
+        except pynvml.NVMLError as e:
+            raise RuntimeError(
+                f"Failed to get free memory for device {device_idx} (global index: {global_device_idx}): {e}"
+            )
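
Note: these helpers resolve a logical CUDA device (respecting CUDA_VISIBLE_DEVICES) to a stable NVML UUID, which is what the weight-update path uses to key IPC handles per device. A short usage sketch; the printed values are illustrative:

# Usage sketch for the new NVML helpers added above.
from tensorrt_llm._torch.utils import get_device_uuid, get_free_memory_bytes

uuid = get_device_uuid(0)          # e.g. "GPU-8f9f4a2c-...."
free = get_free_memory_bytes(0)    # free bytes on the same physical device
print(f"device 0 -> {uuid}, {free / 1e9:.1f} GB free")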

tensorrt_llm/executor/executor.py

Lines changed: 19 additions & 0 deletions
@@ -204,6 +204,25 @@ def generate(

         return futures

+    def async_update_weights_from_ipc_handles(self, handles: dict):
+        update_weights_request = GenerationRequest([], SamplingParams(end_id=0))
+        update_weights_request.set_weight_ipc_handles(handles)
+        result = self.submit(update_weights_request)
+        return result
+
+    def async_sleep(self, level: int = 1):
+        sleep_request = GenerationRequest([], SamplingParams(end_id=0))
+        sleep_request.set_sleep_level(level)
+        result = self.submit(sleep_request)
+        return result
+
+    def async_wakeup(self):
+        sleep_request = GenerationRequest([], SamplingParams(end_id=0))
+        sleep_request.set_wakeup_level(1)
+        result = self.submit(sleep_request)
+        return result
+
+
     def _get_next_client_id(self):
         # (self._last_client_id + 1) % UINT64_MAX
         self._last_client_id = (self._last_client_id + 1) & ((1 << 64) - 1)
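
Note: the three helpers submit empty GenerationRequests that only carry control metadata, so callers get back the usual result object and can wait on it to confirm the operation completed. A sketch of the intended call pattern, assuming `executor` is the GenerationExecutor instance and reusing the hypothetical build_ipc_handles producer sketched earlier; the blocking .result() wait mirrors how generation results are normally consumed and is an assumption, not a documented contract:

# Sketch of driving the new control helpers from a training loop.
handles = build_ipc_handles(trainer_model.named_parameters())  # hypothetical producer, see sketch above

executor.async_sleep(level=1).result()                          # quiesce the executor
executor.async_update_weights_from_ipc_handles(handles).result()
executor.async_wakeup().result()                                # resume serving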
