diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index b863c327674..78f93a0b3c4 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -735,6 +735,12 @@ class WindowBlockManager return 0; } + void resetReuseState() + { + mContextBlocksByHash.clear(); + mCachedBlocksRoot = std::make_shared(KVCacheBlock::kCachedBlocksRootId, tensorrt_llm::kernels::KVCacheIndex{0}); + } + private: //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq. void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx); @@ -1120,6 +1126,13 @@ class BlockManager return mWindowBlockManagers.at(windowSize).getPool(relativePoolIndex); } + void resetReuseState() + { + for (auto& [windowSize, manager] : mWindowBlockManagers) + { + manager.resetReuseState(); + } + } private: [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const { @@ -1290,6 +1303,7 @@ class BaseKVCacheManager virtual void refreshBlocks() = 0; virtual void flushIterationEvents() = 0; + virtual void resetReuseState() = 0; [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock); @@ -1633,6 +1647,11 @@ class KVCacheManager : public BaseKVCacheManager mBlockManager.flushIterationEvents(); } + void resetReuseState() override + { + mBlockManager.resetReuseState(); + } + /// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity. /// /// @param inputLength The number of input tokens in the sequence. diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index a75db66eaee..9d2af62ed3c 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -420,7 +420,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds) .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds) - .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents); + .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents) + .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState); py::enum_(m, "CacheType") .value("SELF", tbk::CacheType::kSELF) diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index a8ce31bf2ce..17dc332eb2f 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -724,6 +724,9 @@ def load_single_module(name, module): for new_name in params_map[names[-1]]: fw = filter_weights('.'.join(names[:-1] + [new_name]), weights) + # tmp fixes to enable partial updates in old path + if not fw: + continue if new_name in ['k_proj', 'v_proj']: num_kv_heads_list = [num_kv_heads ] * len(fw) if isinstance( @@ -740,15 +743,18 @@ def load_single_module(name, module): } module_weights.append(fw) - module.load_weights(weights=module_weights) + if module_weights: + module.load_weights(weights=module_weights) + else: module_weights = filter_weights(name, weights) - if hasattr(module, 'load_weights'): - module.load_weights(weights=[module_weights]) - else: - for n, p in 
module._parameters.items(): - if p is not None: - p.data.copy_(module_weights[n][:]) + if module_weights: + if hasattr(module, 'load_weights'): + module.load_weights(weights=[module_weights]) + else: + for n, p in module._parameters.items(): + if p is not None: + p.data.copy_(module_weights[n][:]) if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL", False) in ["True", "true", "1", "yes", "y"]: diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 461c5de941e..8a801b4862c 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -163,7 +163,8 @@ def __init__(self, return_log_probs: bool = False, return_context_logits: bool = False, return_generation_logits: bool = False, - exclude_last_generation_logits: bool = False): + exclude_last_generation_logits: bool = False, + success: bool = False): self._streaming = streaming self._context_logits = LogitsStorage( prompt_len, use_device_memory) if return_context_logits else None @@ -171,6 +172,7 @@ def __init__(self, max_new_tokens, use_device_memory, exclude_last_generation_logits ) if return_generation_logits else None self._log_probs = LogProbStorage() if return_log_probs else None + self._success = success def append_context_logits(self, context_logits: torch.Tensor): if self._context_logits: @@ -246,8 +248,9 @@ def __getattr__(self, item): return getattr(result, item) def deserialize(self): - self._result = tensorrt_llm.bindings.executor.deserialize_result( - self._result) + if self._result is not None: + self._result = tensorrt_llm.bindings.executor.deserialize_result( + self._result) @dataclass diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index d6e5b3e4da3..4a09c223a8a 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1087,7 +1087,6 @@ def init_meta_tensor(t: torch.Tensor): weights = load_weights(model.llm_checkpoint_dir) else: weights = load_weights(checkpoint_dir) - model.load_weights(weights) if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 5a00d017a2d..95d5b937a19 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -17,6 +17,7 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager +from tensorrt_llm._torch.utils import get_device_uuid from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank, is_trace_enabled, nvtx_range, trace_func) from tensorrt_llm.bindings.executor import (DisServingRequestStats, @@ -33,7 +34,7 @@ from ..speculative.drafter import Drafter from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, - LlmResponse, executor_request_to_llm_request) + LlmResponse, LlmResult, executor_request_to_llm_request, PyResult) from .model_engine import ModelEngine from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler from .scheduler import RequestScheduler, ScheduledRequests @@ -51,6 +52,9 @@ PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE" SHUTDOWN_REQUEST_ID = -1 +UPDATE_WEIGHT_REQUEST_ID = -2 +SLEEP_REQUEST_ID = -3 +WAKEUP_REQUEST_ID = -4 
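The negative sentinel IDs above turn the existing request queue into a lightweight control channel: a queue item carrying one of these values bypasses scheduling and is routed to a dedicated handler instead of the model. A minimal sketch of that dispatch convention, assuming the `_update_weight`, `_sleep` and `_wakeup` handler methods added further down in this file; `dispatch_control_item` and its `executor` argument are hypothetical illustrations, not part of the change.

# Sketch only: mirrors the sentinel-ID convention used by RequestQueueItem and the executor loops below.
SHUTDOWN_REQUEST_ID = -1
UPDATE_WEIGHT_REQUEST_ID = -2
SLEEP_REQUEST_ID = -3
WAKEUP_REQUEST_ID = -4

def dispatch_control_item(executor, item) -> bool:
    """Return True if `item` was a control request and has been handled."""
    handlers = {
        UPDATE_WEIGHT_REQUEST_ID: executor._update_weight,  # rebuild weights from IPC handles
        SLEEP_REQUEST_ID: executor._sleep,
        WAKEUP_REQUEST_ID: executor._wakeup,
    }
    handler = handlers.get(item.id)
    if handler is None:
        return False  # normal generation request (id > 0), cancel, or shutdown
    handler(item)
    return True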
@dataclasses.dataclass @@ -59,6 +63,9 @@ class RequestQueueItem: request: Optional[ExecutorRequest] = None is_canceled_request: bool = False query: Optional[list] = None # only used in `StarAttention` + weight_ipc_handles: Optional[dict] = None + sleep_level: Optional[int] = None + wakeup_level: Optional[int] = None @property def is_shutdown_request(self): @@ -66,8 +73,15 @@ def is_shutdown_request(self): @property def is_normal_request(self): - return not (self.is_shutdown_request or self.is_canceled_request) + return self.id > 0 and not self.is_canceled_request + def is_update_weight_request(self): + return self.id == UPDATE_WEIGHT_REQUEST_ID + def is_sleep_request(self): + return self.id == SLEEP_REQUEST_ID + + def is_wakeup_request(self): + return self.id == WAKEUP_REQUEST_ID def _get_from_request_queue( request_queue, @@ -244,6 +258,7 @@ def __init__(self, self.num_fetch_requests_cur_rank = 0 self.num_fetch_requests = 0 self.shutdown_event = threading.Event() + self.request_accumulator: List[RequestQueueItem] = [] # response used data self.response_lock = threading.Lock() @@ -287,6 +302,8 @@ def __init__(self, self.draft_model_engine.warmup(self.resource_manager) self.is_shutdown = False + self.is_control_request = False + self.control_request_id = 0 self.stats_lock = threading.Lock() self.stats = [] @@ -465,7 +482,10 @@ def wait_shutdown(self): def enqueue_request(self, request: ExecutorRequest, - query: Optional[List] = None): + query: Optional[List] = None, + weight_ipc_handles: Optional[dict] = None, + sleep_level: Optional[int] = None, + wakeup_level: Optional[int] = None): """ Enqueue a new request, query is only used in `StarAttention`. """ @@ -476,10 +496,17 @@ def enqueue_request(self, if self.enable_iter_perf_stats: self.start_times[req_id] = time.time() - if query is not None: + if weight_ipc_handles is not None: + self.request_queue.put(RequestQueueItem(UPDATE_WEIGHT_REQUEST_ID, None, False, None, weight_ipc_handles)) + elif sleep_level is not None: + self.request_queue.put(RequestQueueItem(SLEEP_REQUEST_ID, None, False, None, None, sleep_level)) + elif wakeup_level is not None: + self.request_queue.put(RequestQueueItem(WAKEUP_REQUEST_ID, None, False, None, None, None, wakeup_level)) + elif query is not None: self.request_queue.put(RequestQueueItem(req_id, request, query)) else: self.request_queue.put(RequestQueueItem(req_id, request)) + #self.request_queue.put(RequestQueueItem(req_id, request, False, query, weight_ipc_handles, sleep_level, wakeup_level)) self.next_req_id += 1 finally: self.enqueue_lock.release() @@ -756,6 +783,18 @@ def _executor_loop_pp(self): new_requests = self._fetch_new_requests() if self.should_stop_processing: break + if self.is_control_request: + self.is_control_request = False + assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}" + if (new_requests[0].is_update_weight_request()): + self._update_weight(new_requests[0]) + elif (new_requests[0].is_sleep_request()): + self._sleep(new_requests[0]) + elif (new_requests[0].is_wakeup_request()): + self._wakeup(new_requests[0]) + else: + assert False, "Invalid control request" + continue if self.enable_iter_perf_stats: iter_stats = self._get_init_iter_stats( @@ -907,6 +946,18 @@ def _executor_loop(self): new_requests = self._fetch_new_requests() if self.should_stop_processing: break + if self.is_control_request: + self.is_control_request = False + assert len(new_requests) == 1, f"control request should be the only request in the list, but 
got {len(new_requests)}" + if (new_requests[0].is_update_weight_request()): + self._update_weight(new_requests[0]) + elif (new_requests[0].is_sleep_request()): + self._sleep(new_requests[0]) + elif (new_requests[0].is_wakeup_request()): + self._wakeup(new_requests[0]) + else: + assert False, "Invalid control request" + continue if self.kv_cache_transceiver: self._check_disagg_gen_transfer_status() @@ -1033,6 +1084,70 @@ def _prepare_draft_requests(self): logger.error(f"Encountered an error in decode: {error_msg}") self._handle_errors(error_msg) + def reset_prefix_cache(self): + self.kv_cache_manager.reset_reuse_state() + + def update_weights(self, weights): + # Load weights into the model + self.model_engine.model.load_weights(weights) + torch.cuda.synchronize() + + self.reset_prefix_cache() + + def update_weight_from_ipc_handles(self, handles): + """ + Update model weights from IPC handles. + + Args: + ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles. + {device_uuid: all_handles} + """ + from tensorrt_llm._torch.utils import get_device_uuid + device_uuid = get_device_uuid(self.device_id) + + if device_uuid not in handles: + raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles") + + try: + weights = {} + all_handles = handles[device_uuid] + + for param_name, tensor_handle in all_handles: + func, args = tensor_handle + list_args = list(args) + list_args[6] = self.device_id # Set target device + tensor = func(*list_args) + weights[param_name] = tensor + + self.update_weights(weights) + + except Exception as e: + logger.error(f"failed to update weights from ipc handles: {e}") + return False + + def _sleep(self, sleep_request): + self.is_sleep_request = False + self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=sleep_request.id)}) + + def _wakeup(self, wakeup_request): + self.is_wakeup_request = False + self._enqueue_responses({wakeup_request.id: LlmResponse(request_id=wakeup_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=wakeup_request.id)}) + + def _update_weight(self, update_weight_request): + self.is_update_weight_request = False + + try: + self.update_weight_from_ipc_handles(update_weight_request.weight_ipc_handles) + update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=update_weight_request.id) + self._enqueue_responses({update_weight_request.id: update_weight_response}) + except Exception as e: + print( + f"Error in update_weights_from_ipc_handles: {e}" + ) + raise e + #update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=False), is_final=True), client_id=update_weight_request.id) + #self._enqueue_responses({update_weight_request.id: update_weight_response}) + def _executor_loop_overlap(self): torch.cuda.set_device(self.device_id) if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver: @@ -1052,6 +1167,18 @@ def _executor_loop_overlap(self): new_requests = self._fetch_new_requests() if self.should_stop_processing: break + if self.is_control_request: + self.is_control_request = False + assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}" 
+ if (new_requests[0].is_update_weight_request()): + self._update_weight(new_requests[0]) + elif (new_requests[0].is_sleep_request()): + self._sleep(new_requests[0]) + elif (new_requests[0].is_wakeup_request()): + self._wakeup(new_requests[0]) + else: + assert False, "Invalid control request" + continue if self.kv_cache_transceiver: self._check_disagg_gen_transfer_status() @@ -1263,20 +1390,43 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]: new_requests, py_request_objects = self._broadcast_new_requests( new_requests, py_request_objects) + self.request_accumulator.extend(new_requests) + # drop requests arriving after shutdown valid_new_requests = [] - for req_item in new_requests: + find_control_request = False + for i, req_item in enumerate(self.request_accumulator): if req_item.is_shutdown_request: self.is_shutdown = True + find_control_request = True + break + if req_item.is_update_weight_request() or req_item.is_sleep_request() or req_item.is_wakeup_request(): + find_control_request = True + self.control_request_id = req_item.id break elif req_item.is_canceled_request: self.canceled_req_ids.append(req_item.id) + + if (find_control_request): + if (i==0): + if not self.is_shutdown: + valid_new_requests = self.request_accumulator[:1] + self.is_control_request = True + self.request_accumulator = self.request_accumulator[1:] + return valid_new_requests else: - valid_new_requests.append(req_item) + valid_new_requests = self.request_accumulator[:i] + self.request_accumulator = self.request_accumulator[i:] + else: + valid_new_requests = self.request_accumulator + self.request_accumulator = [] + # Check if the beam width of the requests is equal to the max_beam_width for req_item in valid_new_requests: assert req_item.request.sampling_config.beam_width == self.max_beam_width, f"Request beam width {req_item.request.sampling_config.beam_width} is not equal to max_beam_width {self.max_beam_width}. This is not supported!" + new_requests = valid_new_requests + if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp) and self.dist.rank > 0: for attr_name, req_obj_dict in py_request_objects: diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index e52096727c6..d1d22b36c38 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -568,6 +568,10 @@ def get_kv_cache_stats(self): def rewind_kv_cache(self, request: LlmRequest, rewind_len: int): self.impl.rewind_kv_cache(request.py_request_id, rewind_len) + def reset_reuse_state(self): + """Reset the reuse state of the KV cache manager.""" + self.impl.reset_reuse_state() + def _get_window_size_to_layers(self) -> dict[int, list[int]]: """ Get the window size to layers mapping. 
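The `update_weight_from_ipc_handles` method above consumes handles produced by `torch.multiprocessing.reductions.reduce_tensor`: each handle is a `(rebuild_func, args)` pair, and calling `rebuild_func(*args)` in the consuming process maps the producer's CUDA allocation without copying it. A minimal sketch of both ends, assuming (as the executor code above does) that index 6 of the args tuple is the storage device index; `pack_for_ipc` and `unpack_from_ipc` are illustrative names, not APIs introduced by this change.

import torch
from torch.multiprocessing.reductions import reduce_tensor

def pack_for_ipc(name: str, t: torch.Tensor):
    # Producer (trainer) side: detach and reduce a CUDA tensor into an IPC-shippable handle.
    return name, reduce_tensor(t.detach())

def unpack_from_ipc(handle, device_id: int) -> torch.Tensor:
    # Consumer (inference) side: rebuild the tensor, retargeting the local device index.
    rebuild_func, args = handle
    args = list(args)
    args[6] = device_id  # storage-device slot, per the assumption noted above
    return rebuild_func(*args)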
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index f687e9d9f55..21514d1945e 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -3,9 +3,10 @@ import threading from dataclasses import dataclass from enum import Enum -from typing import Dict, List +from typing import Dict, List, Generator import torch +import pynvml from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor from tensorrt_llm.math_utils import ceil_div, pad_up @@ -259,3 +260,73 @@ def set_piecewise_cuda_graph_flag(enable: bool): def get_piecewise_cuda_graph_flag() -> bool: global _enable_piecewise_cuda_graph return _enable_piecewise_cuda_graph + + +@contextlib.contextmanager +def nvml_context() -> Generator[None, None, None]: + """Context manager for NVML initialization and shutdown. + + Raises: + RuntimeError: If NVML initialization fails + """ + try: + pynvml.nvmlInit() + yield + except pynvml.NVMLError as e: + raise RuntimeError(f"Failed to initialize NVML: {e}") + finally: + try: + pynvml.nvmlShutdown() + except: + pass + +def device_id_to_physical_device_id(device_id: int) -> int: + """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES.""" + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + try: + physical_device_id = int(device_ids[device_id]) + return physical_device_id + except ValueError: + raise RuntimeError( + f"Failed to convert logical device ID {device_id} to physical device ID. Available devices are: {device_ids}." + ) + else: + return device_id + +def get_device_uuid(device_idx: int) -> str: + """Get the UUID of a CUDA device using NVML.""" + # Convert logical device index to physical device index + + global_device_idx = device_id_to_physical_device_id(device_idx) + + # Get the device handle and UUID + with nvml_context(): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx) + uuid = pynvml.nvmlDeviceGetUUID(handle) + # Ensure the UUID is returned as a string, not bytes + if isinstance(uuid, bytes): + return uuid.decode("utf-8") + elif isinstance(uuid, str): + return uuid + else: + raise RuntimeError( + f"Unexpected UUID type: {type(uuid)} for device {device_idx} (global index: {global_device_idx})" + ) + except pynvml.NVMLError as e: + raise RuntimeError( + f"Failed to get device UUID for device {device_idx} (global index: {global_device_idx}): {e}" + ) + +def get_free_memory_bytes(device_idx: int) -> float: + """Get the free memory of a CUDA device in bytes using NVML.""" + global_device_idx = device_id_to_physical_device_id(device_idx) + with nvml_context(): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx) + return pynvml.nvmlDeviceGetMemoryInfo(handle).free + except pynvml.NVMLError as e: + raise RuntimeError( + f"Failed to get free memory for device {device_idx} (global index: {global_device_idx}): {e}" + ) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 2e84d9abc44..5493140f28f 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -201,6 +201,25 @@ def generate( return futures + def async_update_weights_from_ipc_handles(self, handles: dict): + update_weights_request = GenerationRequest([], SamplingParams(end_id=0)) + update_weights_request.set_weight_ipc_handles(handles) + result = self.submit(update_weights_request) + return result + + def async_sleep(self, level: int = 1): + sleep_request = GenerationRequest([], 
SamplingParams(end_id=0)) + sleep_request.set_sleep_level(level) + result = self.submit(sleep_request) + return result + + def async_wakeup(self): + sleep_request = GenerationRequest([], SamplingParams(end_id=0)) + sleep_request.set_wakeup_level(1) + result = self.submit(sleep_request) + return result + + def _get_next_client_id(self): # (self._last_client_id + 1) % UINT64_MAX self._last_client_id = (self._last_client_id + 1) & ((1 << 64) - 1) diff --git a/tensorrt_llm/executor/request.py b/tensorrt_llm/executor/request.py index 886831d0723..352a4b9a298 100644 --- a/tensorrt_llm/executor/request.py +++ b/tensorrt_llm/executor/request.py @@ -110,12 +110,53 @@ def __init__( self.kv_cache_retention_config = kv_cache_retention_config self.id: Optional[int] = None self.disaggregated_params = disaggregated_params + self.weight_ipc_handles: Optional[dict] = None + self.sleep_level: Optional[int] = None + self.wakeup_level: Optional[int] = None def set_id(self, id): - assert self.id is None, f"Request ID is already set: {self.id}" - self.id = id + if self.prompt_token_ids != []: + assert self.id is None, f"Request ID is already set: {self.id}" + self.id = id + return self + + def set_weight_ipc_handles(self, handles: dict): + assert self.prompt_token_ids == [], "Prompt token ids must be empty for weight update request" + self.id = -2 + self.weight_ipc_handles: dict = handles + return self + + def set_sleep_level(self, level: int): + assert self.prompt_token_ids == [], "Prompt token ids must be empty for sleep request" + self.id = -3 + self.sleep_level = level + return self + + def set_wakeup_level(self, level: int): + assert self.prompt_token_ids == [], "Prompt token ids must be empty for wakeup request" + self.id = -4 + self.wakeup_level = level return self + def is_shutdown_request(self) -> bool: + return self.id == -1 + + def is_weight_update_request(self) -> bool: + return self.id == -2 + + def is_sleep_request(self) -> bool: + return self.id == -3 + + def is_wakeup_request(self) -> bool: + return self.id == -4 + + def is_normal_request(self) -> bool: + return self.id > 0 + + + + def debug_print(self, message: str = ""): + print(f"GenerationRequest {self.id} {message}") class CancellingRequest: ''' The request to cancel a generation. 
''' diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 9cd539f33b3..f4b1b872af2 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -16,7 +16,7 @@ from ..llmapi.tracer import global_tracer from ..llmapi.utils import AsyncQueue from ..sampling_params import LogprobParams, SamplingParams -from .utils import ErrorResponse, has_event_loop, is_llm_response +from .utils import ErrorResponse, has_event_loop, is_llm_response, is_update_weights_response, is_sleep_response, is_wakeup_response if TYPE_CHECKING: from .executor import GenerationExecutor @@ -146,6 +146,7 @@ def __init__(self, self.disaggregated_params = None self.decoding_iter = 0 self._done = False + self._success = False if has_event_loop(): self.aqueue = AsyncQueue() @@ -303,6 +304,7 @@ def _handle_response(self, response_result.deserialize() self._done = response_result.is_final + self._success = True # TODO: replace with response_result._py_result._success context_phase_params = response_result.context_phase_params self.decoding_iter = response_result.decoding_iter if context_phase_params is not None: @@ -331,6 +333,15 @@ def _handle_response(self, if self._background_error_handler and ( handler := self._background_error_handler()): handler() + elif is_update_weights_response(response): + self._success = response.result._py_result._success + self._done = True + elif is_sleep_response(response): + self._success = response.result._py_result._success + self._done = True + elif is_wakeup_response(response): + self._success = response.result._py_result._success + self._done = True elif isinstance(response, ErrorResponse): if self._background_error_handler is not None and ( handler := self._background_error_handler()): @@ -457,6 +468,10 @@ def aborted(self) -> bool: def finished(self) -> bool: return self._done + @property + def success(self) -> bool: + return self._success + def clear_logprob_params(self) -> None: # Remove temporary attribute used in executor # for a cleaner external-facing output. 
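With `_success` plumbed through, a control operation resolves like a normal generation: the executor enqueues an `LlmResponse` whose `PyResult` carries `success`, `_handle_response` marks the result done, and the caller reads the new `success` property. A usage sketch against the llm-api methods added later in this diff; `llm` (an already constructed `tensorrt_llm.LLM`) and `all_handles` (a device-UUID-to-handles dict like the one built in the new test) are assumed inputs.

def push_weights(llm, all_handles) -> None:
    # Sketch: issue a weight-update control request and check its outcome.
    fut = llm.update_weights_from_ipc_handles_async(all_handles)
    fut.result()          # blocks until _handle_response sees the control response
    if not fut.success:   # property added in this hunk
        raise RuntimeError("executor did not acknowledge the weight update")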
@@ -527,7 +542,7 @@ def _exception(self, timeout: Optional[float] = None): def _repr_fields(self): return [ - 'request_id', 'prompt_token_ids', 'outputs', 'finished', + 'request_id', 'prompt_token_ids', 'outputs', 'finished', 'success', "context_logits" ] diff --git a/tensorrt_llm/executor/utils.py b/tensorrt_llm/executor/utils.py index 0a046e0e50f..b21f8c98da3 100644 --- a/tensorrt_llm/executor/utils.py +++ b/tensorrt_llm/executor/utils.py @@ -147,8 +147,16 @@ class WorkerCommIpcAddrs(NamedTuple): def is_llm_response(instance): - return hasattr(instance, "result") + return hasattr(instance, "result") and hasattr(instance, "request_id") and instance.request_id > 0 +def is_update_weights_response(instance): + return hasattr(instance, "result") and hasattr(instance, "request_id") and instance.request_id == -2 + +def is_sleep_response(instance): + return hasattr(instance, "result") and hasattr(instance, "request_id") and instance.request_id == -3 + +def is_wakeup_response(instance): + return hasattr(instance, "result") and hasattr(instance, "request_id") and instance.request_id == -4 def print_alive_threads(): assert enable_llm_debug( diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index da90fc8fe93..101a79534ab 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -38,7 +38,7 @@ from .result import (GenerationResult, IterationResult, LogProbsResult, ResponseWrapper, compute_logprobs) from .utils import (ErrorResponse, IntraProcessQueue, RequestError, - WorkerCommIpcAddrs, has_event_loop, is_llm_response) + WorkerCommIpcAddrs, has_event_loop, is_llm_response, is_update_weights_response, is_sleep_response, is_wakeup_response) __all__ = [ "GenerationExecutorWorker", @@ -443,7 +443,15 @@ def _deduce_max_tokens(request: GenerationRequest, f"prompt length {splited_prompt_len} plus query length {query_token_len} " f"is larger than max_seq_len {executor_config.max_seq_len}") return default_max_tokens - + if request.is_weight_update_request(): + req_id = self.engine.enqueue_request(request, weight_ipc_handles=request.weight_ipc_handles) + return req_id + elif request.is_sleep_request(): + req_id = self.engine.enqueue_request(request, sleep_level=request.sleep_level) + return req_id + elif request.is_wakeup_request(): + req_id = self.engine.enqueue_request(request, wakeup_level=request.wakeup_level) + return req_id try: executor_request = tllm.Request( client_id=request.id, @@ -492,7 +500,6 @@ def _deduce_max_tokens(request: GenerationRequest, lp = request.sampling_params.logits_processor executor_request.py_logits_post_processors = lp if isinstance( lp, list) else [lp] - if request.query_token_ids is not None: # pytorch star attention workflow # a workaround to avoid public interface update @@ -1008,6 +1015,12 @@ def _send_rsp( if is_llm_response(response): if response.has_error() or response.result.is_final: worker._pop_result(response.client_id) + elif is_update_weights_response(response): + worker._pop_result(response.client_id) + elif is_sleep_response(response): + worker._pop_result(response.client_id) + elif is_wakeup_response(response): + worker._pop_result(response.client_id) elif isinstance(response, ErrorResponse): worker._pop_result(response.client_id) else: diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 9cd606e3227..69c5bdf6ab2 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -603,6 +603,33 @@ def _build_model(self): self.llm_build_stats)) self._engine_dir, 
self._hf_model_dir = model_loader() + def update_weights_from_ipc_handles_async(self, handles: dict): + result = self._executor.async_update_weights_from_ipc_handles(handles) + return result + + def update_weights_from_ipc_handles(self, handles: dict): + result = self.update_weights_from_ipc_handles_async(handles) + result.result() + return result + + def sleep_async(self, level: int = 1): + result = self._executor.async_sleep(level) + return result + + def sleep(self, level: int): + result = self.sleep_async(level) + result.result() + return result + + def wakeup_async(self): + result = self._executor.async_wakeup() + return result + + def wakeup(self): + result = self.wakeup_async() + result.result() + return result + @property def _on_trt_backend(self) -> bool: return isinstance(self.args, TrtLlmArgs) diff --git a/tests/unittest/llmapi/test_llm_update_weights.py b/tests/unittest/llmapi/test_llm_update_weights.py new file mode 100644 index 00000000000..0bb883ea22b --- /dev/null +++ b/tests/unittest/llmapi/test_llm_update_weights.py @@ -0,0 +1,363 @@ +import argparse +import torch +import torch.distributed as dist +import atexit +import os +from typing import Any, Optional +from tensorrt_llm import SamplingParams +from tensorrt_llm import LLM +from tensorrt_llm.llmapi.llm_args import KvCacheConfig +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + StateDictType, + MixedPrecision, + ShardedStateDictConfig, + FullStateDictConfig +) +#from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def init_distributed(): + """Initialize distributed training""" + if "LOCAL_RANK" not in os.environ: + return 1, 0, torch.device("cuda:0") + + # Set default environment variables if not already set + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "29500" + + dist.init_process_group(backend="nccl") + world_size = dist.get_world_size() + rank = dist.get_rank() + torch.cuda.set_device(rank) + return world_size, rank, torch.device(f"cuda:{rank}") + +def exit_distributed(): + """Exit distributed training""" + if dist.is_initialized(): + dist.destroy_process_group() +class fsdp_interface: + def __init__(self, model_dir): + self.model_dir = model_dir + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self.device = torch.device(f"cuda:{self.rank}") + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + self.model = self.load_fsdp_model(model_dir) + + def load_fsdp_model(self, model_dir): + """Load and initialize FSDP model""" + # Initialize distributed setup + print(f"World size: {self.world_size}, Rank: {self.rank}, Device: {self.device}") + + # Setup mixed precision policy for FSDP + mixed_precision_policy = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32 + ) + + if self.rank == 0: + print(f"Loading FSDP model from {model_dir}") + + # Initialize FSDP model + fsdp_model = AutoModelForCausalLM.from_pretrained( + model_dir, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + device_map=self.device + ) + + # Print model info + if self.rank == 0: + total_params = sum(p.numel() for p in fsdp_model.parameters()) + trainable_params = sum(p.numel() for p in fsdp_model.parameters() if p.requires_grad) + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: 
{trainable_params:,}") + print(f"Model device: {next(fsdp_model.parameters()).device}") + + # Wrap model with FSDP + fsdp_model = FSDP( + fsdp_model, + mixed_precision=mixed_precision_policy, + device_id=torch.cuda.current_device(), + use_orig_params=True + ) + + if self.rank == 0: + print("FSDP model initialized successfully") + + self._held_streamed_param_reference = None + self._held_sharded_state_dict_reference = None + + return fsdp_model + + + def report_device_id(self) -> str: + """Report the UUID of the current CUDA device using NVML. + + Returns: + str: UUID of the device in the format "GPU-xxxxx" + """ + from tensorrt_llm._torch.utils import get_device_uuid + # Get current device index from torch + device_idx = torch.cuda.current_device() + # Get device UUID using NVML + return get_device_uuid(device_idx) + + @torch.no_grad() + def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: + # If the model is not FSDP, then we need to manually move it to the GPU + # For an FSDP model, model.state_dict() will move the params to the GPU + if not isinstance(self.model, FSDP): + self.model = self.manual_load_to_gpu(self.model) + self._held_sharded_state_dict_reference = self.model.state_dict() + else: + # Get sharded state dict instead of full state dict for FSDP1 + with FSDP.state_dict_type( + self.model, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig() + ): + self._held_sharded_state_dict_reference = self.model.state_dict() + + # Collect info for streaming multiple tensors + ### state_dict_info = [] + ### for name, tensor in self._held_sharded_state_dict_reference.items(): + ### # dtensor's numel will return complete tensor instead of only local tensor + ### size_in_bytes = tensor.element_size() * tensor.numel() + ### state_dict_info.append((name, size_in_bytes)) + self.refit_param_info = [] + for name, tensor in self._held_sharded_state_dict_reference.items(): + # dtensor's numel will return complete tensor instead of only local tensor + size_in_bytes = tensor.element_size() * tensor.numel() + self.refit_param_info.append((name, size_in_bytes)) + + from tensorrt_llm._torch.utils import get_free_memory_bytes + #print(f"State dict info: {state_dict_info}") + # Collect current available memory for refit + ## Get current device index from torch + device_idx = torch.cuda.current_device() + ## Get device free memory using NVML + total_available_bytes = get_free_memory_bytes(device_idx) + ## Use 80% of the free memory for safety + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8") + total_available_bytes *= float(memory_ratio) + + return self.refit_param_info, total_available_bytes + + @torch.no_grad() + def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: + from torch.distributed.tensor import DTensor + from torch.multiprocessing.reductions import reduce_tensor + + assert self._held_sharded_state_dict_reference is not None, ( + "prepare_weights_for_ipc must be called before get_weights_ipc_handles" + ) + + # Clean up the held tensors to reduce peak memory + if self._held_streamed_param_reference is not None: + del self._held_streamed_param_reference + self._held_streamed_param_reference = None + + converted_params = {} + for key in keys: + # Get full_tensor for dtensor (GPU > 1) + if not key.startswith("model."): + continue + tensor = self._held_sharded_state_dict_reference[key] + if isinstance(tensor, DTensor): + full_tensor = tensor.full_tensor() + else: + full_tensor = tensor + # Convert parameters to 
the configured dtype + #print(f"FSDP rank {self.rank} name: {key}, shape: {full_tensor.shape}, {full_tensor[0]}") + converted_params[key] = full_tensor + + # Temporary record the full tensor for cleanup + # It is needed for cleanup the last full_tensor in the refit process + self._held_streamed_param_reference = converted_params + + # Get device UUID for IPC + device_uuid = self.report_device_id() + # Create handles for the tensors + all_handles = [] + for key, p in converted_params.items(): + handle = reduce_tensor(p.detach()) + all_handles.append((key, handle)) + + #print(f"device_uuid: {device_uuid}, All handles keys: {[key for key, _ in all_handles]}") + print(f"device_uuid: {device_uuid}") + return {device_uuid: all_handles} + + @torch.no_grad() + def prepare_weights_for_ipc_refit( + self, _refit_buffer_size_gb: Optional[int] = None + ) -> list[list[str]]: + """Prepare the weights for IPC. + + Returns: + list: A list containing the keys of the parameters, which is grouped by size. + """ + # Get the state_dict_info and available memory from all workers + state_dict_info = self.refit_param_info + + if _refit_buffer_size_gb is not None: + total_available_bytes = _refit_buffer_size_gb * (1024**3) + else: + # Get the minimum available memory from all workers + total_available_bytes = min(result[1] for result in state_dict_info) + + # Group tensors by size + cur_available_bytes = total_available_bytes + grouped_param_keys: list[list[str]] = [] + keys: list[str] = [] + + for key, size_in_bytes in state_dict_info: + if size_in_bytes > cur_available_bytes: + if keys: + grouped_param_keys.append(keys) + keys = [] + cur_available_bytes = total_available_bytes + + keys.append(key) + cur_available_bytes -= size_in_bytes + + if keys: + grouped_param_keys.append(keys) + + return grouped_param_keys + +class trtllm_interface: + def __init__(self, model_dir, tensor_parallel_size): + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self.device = torch.device(f"cuda:{self.rank}") + self.model_dir = model_dir + self.tensor_parallel_size = tensor_parallel_size + self.llm = self.load_trtllm_model(model_dir, tensor_parallel_size) + + def load_trtllm_model(self, model_dir, tensor_parallel_size): + if self.rank == 0: + print("Loading TensorRT-LLM model") + return LLM( + model=model_dir, + tensor_parallel_size=tensor_parallel_size, + #disable_overlap_scheduler=True, + #load_format='auto' + load_format='dummy', + kv_cache_config=KvCacheConfig( + free_gpu_memory_fraction=0.85, + # enable_block_reuse=False + ) + ) + else: + return None + +def cleanup(): + """Cleanup function to destroy process group""" + if dist.is_initialized(): + print(f"Cleaning up process group on rank {dist.get_rank()}") + dist.destroy_process_group() + + +def main(): + parser = argparse.ArgumentParser( + description="LLM models with the PyTorch workflow.") + + parser.add_argument('--model_dir', + type=str, + required=True, + default='/model/Qwen2.5-0.5B-Instruct', + help="Model checkpoint directory.") + + parser.add_argument('--tensor_parallel_size', + type=int, + default=2, + help="Tensor parallel size (number of GPUs to use)") + + parser.add_argument('--use_fsdp', + action='store_true', + help="Use FSDP model loading instead of direct TensorRT-LLM loading") + + args = parser.parse_args() + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + world_size, rank, device = init_distributed() + + sampling_params = 
SamplingParams(max_tokens=32) + + # Load FSDP model + fsdp = fsdp_interface(args.model_dir) + trtllm = trtllm_interface(args.model_dir, args.tensor_parallel_size) + + if rank == 0: + print(f"Collected handles from all {world_size} ranks:") + + # For FSDP mode, we would need additional logic to integrate withTensorRT-LLM + # This is a placeholder for now + if rank == 0: + + outputs = trtllm.llm.generate(prompts, sampling_params) + for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}") + + ## load the model from fsdp + ## then generate the output again + result = trtllm.llm.sleep(1) + print(f"sleep result: {result}") + + result = trtllm.llm.wakeup() + print(f"wakeup result: {result}") + + dict_info, total_available_bytes = fsdp.prepare_weights_for_ipc() + + grouped_param_keys = fsdp.prepare_weights_for_ipc_refit(0.5) + total_num_keys = sum(len(k) for k in grouped_param_keys) + print( + f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups" + ) + + from tensorrt_llm._torch.utils import get_free_memory_bytes + for keys in grouped_param_keys: + handles = fsdp.get_weights_ipc_handles(keys) + #print(f"handles: {handles}") + + # Collect handles from all ranks + all_handles = [None for _ in range(world_size)] + dist.all_gather_object(all_handles, handles) + all_handles = {k: v for d in all_handles for k, v in d.items()} + #print(f"all_handles: {all_handles.keys()}") + + device_idx = torch.cuda.current_device() + total_available_bytes = get_free_memory_bytes(device_idx) + print(f"total_available_bytes: {total_available_bytes}") + + if rank == 0: + result = trtllm.llm.update_weights_from_ipc_handles(all_handles) + print(f"update weights result: {result}") + + if rank == 0: + outputs = trtllm.llm.generate(prompts, sampling_params) + for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}") + + exit_distributed() +if __name__ == '__main__': + main() + +# torchrun --nproc_per_node=2 tests/unittest/llmapi/test_llm_update_weights.py --model_dir /model/Qwen2.5-0.5B-Instruct --tensor_parallel_size 2 +# torchrun --nproc_per_node=2 tests/unittest/llmapi/test_llm_update_weights.py --model_dir /model/Qwen2.5-3B-Instruct/ --tensor_parallel_size 2 \ No newline at end of file