@@ -3,7 +3,7 @@
 import torch.distributed as dist
 import atexit
 import os
-from typing import Any
+from typing import Any, Optional
 from tensorrt_llm import SamplingParams
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi.llm_args import KvCacheConfig
@@ -17,10 +17,6 @@
 #from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-import contextlib
-from typing import Generator
-import pynvml
-
 
 def init_distributed():
     """Initialize distributed training"""
@@ -129,15 +125,29 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         self._held_sharded_state_dict_reference = self.model.state_dict()
 
         # Collect info for streaming multiple tensors
-        state_dict_info = []
+        # Cache (name, size_in_bytes) per parameter so the refit step can
+        # later group keys by size without re-walking the state dict
+        self.refit_param_info = []
         for name, tensor in self._held_sharded_state_dict_reference.items():
             # dtensor's numel will return complete tensor instead of only local tensor
             size_in_bytes = tensor.element_size() * tensor.numel()
-            state_dict_info.append((name, size_in_bytes))
+            self.refit_param_info.append((name, size_in_bytes))
 
+        from tensorrt_llm._torch.utils import get_free_memory_bytes
         #print(f"State dict info: {state_dict_info}")
+        # Collect the currently available device memory for the refit buffer
+        device_idx = torch.cuda.current_device()
+        # get_free_memory_bytes() reports the device's free memory via NVML
+        total_available_bytes = get_free_memory_bytes(device_idx)
+        # Only use a fraction of the free memory (80% by default) for safety
+        memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8")
+        total_available_bytes *= float(memory_ratio)
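+        # Remember the measured budget so prepare_weights_for_ipc_refit()
+        # can fall back to it when no explicit buffer size is passed in
+        self.refit_total_available_bytes = total_available_bytes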
 
-        return state_dict_info
+        return self.refit_param_info, total_available_bytes
 
     @torch.no_grad()
     def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
@@ -183,6 +193,44 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
         print(f"device_uuid: {device_uuid}")
         return {device_uuid: all_handles}
 
+    @torch.no_grad()
+    def prepare_weights_for_ipc_refit(
+        self, _refit_buffer_size_gb: Optional[float] = None
+    ) -> list[list[str]]:
+        """Prepare the weights for IPC refit.
+
+        Returns:
+            list: Parameter keys grouped by size so that each group fits
+                into the refit buffer.
+        """
+        # Reuse the (name, size) pairs collected in prepare_weights_for_ipc()
+        state_dict_info = self.refit_param_info
+
+        if _refit_buffer_size_gb is not None:
+            total_available_bytes = _refit_buffer_size_gb * (1024**3)
+        else:
+            # Fall back to the device memory budget measured in
+            # prepare_weights_for_ipc()
+            total_available_bytes = self.refit_total_available_bytes
+
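+        # Example with hypothetical sizes: given a 1 GiB budget and three
+        # params of 0.6, 0.6 and 0.3 GiB, the greedy loop below yields
+        # [["p0"], ["p1", "p2"]]; a tensor larger than the budget still
+        # gets a group of its own.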
+        # Group tensors by size
+        cur_available_bytes = total_available_bytes
+        grouped_param_keys: list[list[str]] = []
+        keys: list[str] = []
+
+        for key, size_in_bytes in state_dict_info:
+            if size_in_bytes > cur_available_bytes:
+                if keys:
+                    grouped_param_keys.append(keys)
+                keys = []
+                cur_available_bytes = total_available_bytes
+
+            keys.append(key)
+            cur_available_bytes -= size_in_bytes
+
+        if keys:
+            grouped_param_keys.append(keys)
+
+        return grouped_param_keys
+
 class trtllm_interface:
     def __init__(self, model_dir, tensor_parallel_size):
         self.world_size = dist.get_world_size()
@@ -202,6 +250,7 @@ def load_trtllm_model(self, model_dir, tensor_parallel_size):
             #load_format='auto'
             load_format='dummy',
             kv_cache_config=KvCacheConfig(
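+                # Cap the KV-cache pool at 85% of free GPU memory so some
+                # headroom remains (e.g. for the refit weight transfer)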
+                free_gpu_memory_fraction=0.85,
                 enable_block_reuse=False
             )
         )
@@ -251,23 +300,9 @@ def main():
     fsdp = fsdp_interface(args.model_dir)
     trtllm = trtllm_interface(args.model_dir, args.tensor_parallel_size)
 
-    grouped_param_keys = [key for key, size in fsdp.prepare_weights_for_ipc()]
-    handles = fsdp.get_weights_ipc_handles(grouped_param_keys)
-    #print(f"handles: {handles}")
-
-    # Collect handles from all ranks
-    all_handles = [None for _ in range(world_size)]
-    dist.all_gather_object(all_handles, handles)
-    all_handles = {k: v for d in all_handles for k, v in d.items()}
-    print(f"all_handles: {all_handles.keys()}")
-
     if rank == 0:
         print(f"Collected handles from all {world_size} ranks:")
 
-    # Now all_handles contains the handles from each rank
-    # all_handles[0] = handles from rank 0
-    # all_handles[1] = handles from rank 1, etc.
-
     # For FSDP mode, we would need additional logic to integrate with TensorRT-LLM
     # This is a placeholder for now
     if rank == 0:
@@ -286,9 +321,34 @@ def main():
         result = trtllm.llm.wakeup()
         print(f"wakeup result: {result}")
 
-        result = trtllm.llm.update_weights_from_ipc_handles(all_handles)
-        print(f"update weights result: {result}")
+    dict_info, total_available_bytes = fsdp.prepare_weights_for_ipc()
+
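+    # A deliberately small 0.5 GiB refit buffer forces the weights to be
+    # streamed in multiple groups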
+    grouped_param_keys = fsdp.prepare_weights_for_ipc_refit(0.5)
+    total_num_keys = sum(len(k) for k in grouped_param_keys)
+    print(
+        f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
+    )
 
+    from tensorrt_llm._torch.utils import get_free_memory_bytes
+    for keys in grouped_param_keys:
+        handles = fsdp.get_weights_ipc_handles(keys)
+        #print(f"handles: {handles}")
+
+        # Collect handles from all ranks
+        all_handles = [None for _ in range(world_size)]
+        dist.all_gather_object(all_handles, handles)
+        all_handles = {k: v for d in all_handles for k, v in d.items()}
+        #print(f"all_handles: {all_handles.keys()}")
+
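+        # Debug aid: report how much free device memory remains after
+        # materializing this group's IPC handles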
+        device_idx = torch.cuda.current_device()
+        total_available_bytes = get_free_memory_bytes(device_idx)
+        print(f"total_available_bytes: {total_available_bytes}")
+
+        if rank == 0:
+            result = trtllm.llm.update_weights_from_ipc_handles(all_handles)
+            print(f"update weights result: {result}")
+
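+    # Sanity check: generate with the refitted engine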
+    if rank == 0:
         outputs = trtllm.llm.generate(prompts, sampling_params)
         for i, output in enumerate(outputs):
             prompt = output.prompt
@@ -299,4 +359,5 @@ def main():
 if __name__ == '__main__':
     main()
 
-# torchrun --nproc_per_node=2 generate.py --model_dir /model/Qwen2.5-0.5B-Instruct --tensor_parallel_size 2
+# torchrun --nproc_per_node=2 tests/unittest/llmapi/test_llm_update_weights.py --model_dir /model/Qwen2.5-0.5B-Instruct --tensor_parallel_size 2
+# torchrun --nproc_per_node=2 tests/unittest/llmapi/test_llm_update_weights.py --model_dir /model/Qwen2.5-3B-Instruct/ --tensor_parallel_size 2