
Commit 3ad8422

enable_block_reuse=False; clean device uuid code
1 parent 909d641 commit 3ad8422

File tree: 2 files changed, +12 -69 lines


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 5 additions & 9 deletions
@@ -1102,18 +1102,13 @@ def _update_weight(self, update_weight_request):
             weights = {}

             # Process each handle to get the tensor
-            i = 0
             for name, handle in handles:
                 func, args = handle
                 list_args = list(args)
                 # Update device ID to match the current device
                 list_args[6] = self.device_id
                 tensor = func(*list_args)
-                if i % 2 == 0:
-                    weights[name] = tensor
-                else:
-                    weights[name] = tensor # + 1.0
-                i += 1
+                weights[name] = tensor

             # Load weights into the model
             self.model_engine.model.load_weights(weights)
@@ -1123,10 +1118,11 @@ def _update_weight(self, update_weight_request):
             self._enqueue_responses({update_weight_request.id: update_weight_response})
         except Exception as e:
             print(
-                f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}"
+                f"Error in update_weights_from_ipc_handles: {e}"
             )
-            update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=False), is_final=True), client_id=update_weight_request.id)
-            self._enqueue_responses({update_weight_request.id: update_weight_response})
+            raise e
+            #update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=False), is_final=True), client_id=update_weight_request.id)
+            #self._enqueue_responses({update_weight_request.id: update_weight_response})

     def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
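
The handles consumed above are (func, args) pairs where func(*args) rebuilds a CUDA tensor and args[6] carries the device index, which matches the shape of PyTorch's CUDA-IPC reduction. A minimal sketch of how such handles might be produced and consumed, assuming torch.multiprocessing.reductions.reduce_tensor on the sender side; the helper names (build_ipc_handles, rebuild_weights) and tensor names are purely illustrative:

# Sketch only: not the repo's implementation.
import torch
from torch.multiprocessing.reductions import reduce_tensor

# Producer (trainer) side: turn CUDA tensors into serializable IPC handles.
def build_ipc_handles(state_dict):
    handles = []
    for name, tensor in state_dict.items():
        # reduce_tensor returns (rebuild_fn, args); for CUDA tensors args[6] is the device index.
        func, args = reduce_tensor(tensor.detach())
        handles.append((name, (func, args)))
    return handles

# Consumer (executor) side: mirror of the loop in _update_weight above.
def rebuild_weights(handles, device_id):
    weights = {}
    for name, (func, args) in handles:
        list_args = list(args)
        list_args[6] = device_id  # rebind the handle to this rank's device
        weights[name] = func(*list_args)
    return weights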

tests/unittest/llmapi/test_llm_update_weights.py

Lines changed: 7 additions & 60 deletions
@@ -43,63 +43,6 @@ def exit_distributed():
     """Exit distributed training"""
     if dist.is_initialized():
         dist.destroy_process_group()
-
-@contextlib.contextmanager
-def nvml_context() -> Generator[None, None, None]:
-    """Context manager for NVML initialization and shutdown.
-
-    Raises:
-        RuntimeError: If NVML initialization fails
-    """
-    try:
-        pynvml.nvmlInit()
-        yield
-    except pynvml.NVMLError as e:
-        raise RuntimeError(f"Failed to initialize NVML: {e}")
-    finally:
-        try:
-            pynvml.nvmlShutdown()
-        except:
-            pass
-
-def device_id_to_physical_device_id(device_id: int) -> int:
-    """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES."""
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
-        try:
-            physical_device_id = int(device_ids[device_id])
-            return physical_device_id
-        except ValueError:
-            raise RuntimeError(
-                f"Failed to convert logical device ID {device_id} to physical device ID. Available devices are: {device_ids}."
-            )
-    else:
-        return device_id
-
-def get_device_uuid(device_idx: int) -> str:
-    """Get the UUID of a CUDA device using NVML."""
-    # Convert logical device index to physical device index
-    global_device_idx = device_id_to_physical_device_id(device_idx)
-
-    # Get the device handle and UUID
-    with nvml_context():
-        try:
-            handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
-            uuid = pynvml.nvmlDeviceGetUUID(handle)
-            # Ensure the UUID is returned as a string, not bytes
-            if isinstance(uuid, bytes):
-                return uuid.decode("utf-8")
-            elif isinstance(uuid, str):
-                return uuid
-            else:
-                raise RuntimeError(
-                    f"Unexpected UUID type: {type(uuid)} for device {device_idx} (global index: {global_device_idx})"
-                )
-        except pynvml.NVMLError as e:
-            raise RuntimeError(
-                f"Failed to get device UUID for device {device_idx} (global index: {global_device_idx}): {e}"
-            )
-
 class fsdp_interface:
     def __init__(self, model_dir):
         self.model_dir = model_dir
@@ -163,7 +106,7 @@ def report_device_id(self) -> str:
         Returns:
             str: UUID of the device in the format "GPU-xxxxx"
         """
-
+        from tensorrt_llm._torch.utils import get_device_uuid
         # Get current device index from torch
         device_idx = torch.cuda.current_device()
         # Get device UUID using NVML
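
With this hunk the test relies on the helper shipped in TensorRT-LLM instead of its own NVML wrappers removed above. A minimal usage sketch, assuming only the import path shown in the diff:

# Sketch only: mirrors the import used in the hunk above.
import torch
from tensorrt_llm._torch.utils import get_device_uuid

device_idx = torch.cuda.current_device()
print(get_device_uuid(device_idx))  # expected to be a string of the form "GPU-xxxxx"
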
@@ -213,7 +156,8 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
         converted_params = {}
         for key in keys:
             # Get full_tensor for dtensor (GPU > 1)
-            print(f"key: {key}")
+            if not key.startswith("model."):
+                continue
             tensor = self._held_sharded_state_dict_reference[key]
             if isinstance(tensor, DTensor):
                 full_tensor = tensor.full_tensor()
@@ -256,7 +200,10 @@ def load_trtllm_model(self, model_dir, tensor_parallel_size):
                 tensor_parallel_size=tensor_parallel_size,
                 #disable_overlap_scheduler=True,
                 #load_format='auto'
-                #load_format='dummy'
+                load_format='dummy',
+                kv_cache_config=KvCacheConfig(
+                    enable_block_reuse=False
+                )
             )
         else:
             return None
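
This last hunk is the commit's headline change: enable_block_reuse=False presumably keeps the engine from serving KV blocks computed with stale weights after an update, and load_format='dummy' skips loading checkpoint weights that will be pushed over IPC anyway. A minimal sketch of such a constructor call, assuming the LLM API with a placeholder model path and parallelism:

# Sketch only: "facebook/opt-125m" and tensor_parallel_size=2 are placeholders.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    load_format="dummy",  # start with uninitialized weights; real ones arrive via IPC
    kv_cache_config=KvCacheConfig(
        enable_block_reuse=False  # do not reuse cached KV blocks across weight updates
    ),
)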
