@@ -43,63 +43,6 @@ def exit_distributed():
4343 """Exit distributed training"""
4444 if dist .is_initialized ():
4545 dist .destroy_process_group ()
46-
47- @contextlib .contextmanager
48- def nvml_context () -> Generator [None , None , None ]:
49- """Context manager for NVML initialization and shutdown.
50-
51- Raises:
52- RuntimeError: If NVML initialization fails
53- """
54- try :
55- pynvml .nvmlInit ()
56- yield
57- except pynvml .NVMLError as e :
58- raise RuntimeError (f"Failed to initialize NVML: { e } " )
59- finally :
60- try :
61- pynvml .nvmlShutdown ()
62- except :
63- pass
64-
65- def device_id_to_physical_device_id (device_id : int ) -> int :
66- """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES."""
67- if "CUDA_VISIBLE_DEVICES" in os .environ :
68- device_ids = os .environ ["CUDA_VISIBLE_DEVICES" ].split ("," )
69- try :
70- physical_device_id = int (device_ids [device_id ])
71- return physical_device_id
72- except ValueError :
73- raise RuntimeError (
74- f"Failed to convert logical device ID { device_id } to physical device ID. Available devices are: { device_ids } ."
75- )
76- else :
77- return device_id
78-
79- def get_device_uuid (device_idx : int ) -> str :
80- """Get the UUID of a CUDA device using NVML."""
81- # Convert logical device index to physical device index
82- global_device_idx = device_id_to_physical_device_id (device_idx )
83-
84- # Get the device handle and UUID
85- with nvml_context ():
86- try :
87- handle = pynvml .nvmlDeviceGetHandleByIndex (global_device_idx )
88- uuid = pynvml .nvmlDeviceGetUUID (handle )
89- # Ensure the UUID is returned as a string, not bytes
90- if isinstance (uuid , bytes ):
91- return uuid .decode ("utf-8" )
92- elif isinstance (uuid , str ):
93- return uuid
94- else :
95- raise RuntimeError (
96- f"Unexpected UUID type: { type (uuid )} for device { device_idx } (global index: { global_device_idx } )"
97- )
98- except pynvml .NVMLError as e :
99- raise RuntimeError (
100- f"Failed to get device UUID for device { device_idx } (global index: { global_device_idx } ): { e } "
101- )
102-
10346class fsdp_interface :
10447 def __init__ (self , model_dir ):
10548 self .model_dir = model_dir
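For readers tracking the removal: the deleted helpers condense to the pynvml lookup below (standard pynvml calls only, behavior copied from the removed code; the function name is hypothetical). The commit replaces them with TensorRT-LLM's bundled utility, imported in the next hunk.

```python
import os
import pynvml

def lookup_device_uuid(device_idx: int) -> str:  # hypothetical name, for illustration
    # Map the logical index through CUDA_VISIBLE_DEVICES, as the removed helper did.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        device_idx = int(visible.split(",")[device_idx])
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
        uuid = pynvml.nvmlDeviceGetUUID(handle)
        # Older pynvml versions return bytes rather than str.
        return uuid.decode("utf-8") if isinstance(uuid, bytes) else uuid
    finally:
        pynvml.nvmlShutdown()
```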
@@ -163,7 +106,7 @@ def report_device_id(self) -> str:
         Returns:
             str: UUID of the device in the format "GPU-xxxxx"
         """
-
+        from tensorrt_llm._torch.utils import get_device_uuid
         # Get current device index from torch
         device_idx = torch.cuda.current_device()
         # Get device UUID using NVML
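A minimal usage sketch of the swapped-in utility; the import path is taken verbatim from this diff, and a visible CUDA device is assumed.

```python
import torch
from tensorrt_llm._torch.utils import get_device_uuid  # path as in the hunk above

device_idx = torch.cuda.current_device()
print(get_device_uuid(device_idx))  # expected form: "GPU-xxxxxxxx-..."
```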
@@ -213,7 +156,8 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
         converted_params = {}
         for key in keys:
             # Get full_tensor for dtensor (GPU > 1)
-            print(f"key: {key}")
+            if not key.startswith("model."):
+                continue
             tensor = self._held_sharded_state_dict_reference[key]
             if isinstance(tensor, DTensor):
                 full_tensor = tensor.full_tensor()
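The debug print is replaced by a key filter: only entries under the `model.` prefix are converted to IPC handles, and anything else in the sharded state dict is silently skipped. A small illustration, with hypothetical key names:

```python
# Hypothetical state-dict keys; only "model."-prefixed entries survive the filter.
keys = ["model.layers.0.mlp.gate_proj.weight", "lm_head.weight"]
kept = [k for k in keys if k.startswith("model.")]
assert kept == ["model.layers.0.mlp.gate_proj.weight"]
```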
@@ -256,7 +200,10 @@ def load_trtllm_model(self, model_dir, tensor_parallel_size):
                 tensor_parallel_size=tensor_parallel_size,
                 #disable_overlap_scheduler=True,
                 #load_format='auto'
-                #load_format='dummy'
+                load_format='dummy',
+                kv_cache_config=KvCacheConfig(
+                    enable_block_reuse=False
+                )
             )
         else:
             return None
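A sketch of the configuration this hunk enables, assuming the standard tensorrt_llm LLM API; the `KvCacheConfig` import path and the model path are assumptions, while `load_format='dummy'` and `enable_block_reuse=False` come from the hunk itself.

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig  # assumed import path

llm = LLM(
    model="/path/to/checkpoint",  # hypothetical model directory
    tensor_parallel_size=2,
    load_format="dummy",          # allocate weights without reading them from disk
    kv_cache_config=KvCacheConfig(
        enable_block_reuse=False,  # do not reuse KV blocks across requests
    ),
)
```

The dummy load format fits this flow, since the real weights arrive afterwards through the IPC handles exported above; disabling block reuse presumably avoids serving stale KV cache once the weights have been swapped.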