1616)
1717#from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
1818from transformers import AutoModelForCausalLM , AutoTokenizer
19-
19+ from torch . distributed . tensor import DTensor
2020
2121def init_distributed ():
2222 """Initialize distributed training"""
@@ -29,7 +29,7 @@ def init_distributed():
2929 if "MASTER_PORT" not in os .environ :
3030 os .environ ["MASTER_PORT" ] = "29500"
3131
32- dist .init_process_group (backend = "nccl" )
32+ dist .init_process_group (backend = "cpu:gloo,cuda: nccl" )
3333 world_size = dist .get_world_size ()
3434 rank = dist .get_rank ()
3535 torch .cuda .set_device (rank )
@@ -39,6 +39,18 @@ def exit_distributed():
3939 """Exit distributed training"""
4040 if dist .is_initialized ():
4141 dist .destroy_process_group ()
42+
def report_device_id() -> str:
    """Return the UUID of the CUDA device currently selected in torch.

    The device index comes from ``torch.cuda.current_device()`` and is
    translated to a UUID by TensorRT-LLM's ``get_device_uuid`` helper.

    Returns:
        str: The device UUID (documented upstream as NVML-derived, in the
            form "GPU-xxxxx").
    """
    # Imported at call time, exactly as the original did.
    from tensorrt_llm._torch.utils import get_device_uuid

    return get_device_uuid(torch.cuda.current_device())
53+
4254class fsdp_interface :
4355 def __init__ (self , model_dir ):
4456 self .model_dir = model_dir
@@ -96,17 +108,23 @@ def load_fsdp_model(self, model_dir):
96108 return fsdp_model
97109
98110
99- def report_device_id (self ) -> str :
100- """Report the UUID of the current CUDA device using NVML.
101111
102- Returns:
103- str: UUID of the device in the format "GPU-xxxxx"
104- """
105- from tensorrt_llm ._torch .utils import get_device_uuid
106- # Get current device index from torch
107- device_idx = torch .cuda .current_device ()
108- # Get device UUID using NVML
109- return get_device_uuid (device_idx )
def per_tensor_generator(self):
    """Yield ``(name, tensor)`` pairs from the model's state dict.

    The state dict is also stored on
    ``self._held_sharded_state_dict_reference`` so the tensors stay
    referenced after the generator has been consumed.

    Yields:
        tuple: ``(name, param)`` for every entry of the state dict.
    """
    if isinstance(self.model, FSDP):
        # FSDP-wrapped model: request the FULL (unsharded) state dict.
        with FSDP.state_dict_type(
                self.model,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(),
        ):
            state_dict = self.model.state_dict()
    else:
        # A non-FSDP model has to be moved to the GPU by hand before its
        # state dict is taken.
        self.model = self.manual_load_to_gpu(self.model)
        state_dict = self.model.state_dict()

    self._held_sharded_state_dict_reference = state_dict
    yield from state_dict.items()
110128
111129 @torch .no_grad ()
112130 def prepare_weights_for_ipc (self ) -> tuple [list [tuple [str , int ]], float ]:
@@ -182,7 +200,7 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
182200 self ._held_streamed_param_reference = converted_params
183201
184202 # Get device UUID for IPC
185- device_uuid = self . report_device_id ()
203+ device_uuid = report_device_id ()
186204 # Create handles for the tensors
187205 all_handles = []
188206 for key , p in converted_params .items ():
@@ -231,6 +249,25 @@ def prepare_weights_for_ipc_refit(
231249
232250 return grouped_param_keys
233251
class NamedParam:
    """Plain record bundling a parameter with its name and byte size."""

    def __init__(self, name, size, param):
        # name:  key of the parameter in the state dict
        # size:  size of the parameter in bytes, as computed by the caller
        # param: the parameter tensor itself
        self.name, self.size, self.param = name, size, param
257+
class GateAndUp:
    """Collects the gate and up projection halves of one pair.

    Both slots start empty and are filled independently; the pair is
    complete once both have been set. Each half is expected to expose a
    ``size`` attribute (e.g. a ``NamedParam``).
    """

    def __init__(self):
        self.gate = None
        self.up = None

    def set_gate(self, gate):
        """Store the gate half."""
        self.gate = gate

    def set_up(self, up):
        """Store the up half."""
        self.up = up

    def get_size(self):
        """Return the combined ``size`` of both halves.

        Raises:
            ValueError: If either half has not been set yet. Previously
                this surfaced as an opaque AttributeError on ``None``.
        """
        if not self.is_complete():
            raise ValueError(
                "GateAndUp.get_size() requires both gate and up to be set")
        return self.gate.size + self.up.size

    def is_complete(self):
        """Return True once both halves have been set."""
        return self.gate is not None and self.up is not None
270+
234271class trtllm_interface :
235272 def __init__ (self , model_dir , tensor_parallel_size ):
236273 self .world_size = dist .get_world_size ()
@@ -257,13 +294,104 @@ def load_trtllm_model(self, model_dir, tensor_parallel_size):
257294 else :
258295 return None
259296
def update_weights_from_ipc_handles(self, rank, device_handles):
    """Gather every rank's IPC handles on rank 0 and apply them to the LLM.

    Uses ``dist.gather_object``, so every rank of the default process
    group has to make this call.

    Args:
        rank: Rank of the calling process; only rank 0 receives the
            gathered handles and updates ``self.llm``.
        device_handles: Dict contributed by this rank (the caller in this
            file passes ``{device_uuid: [(name, handle), ...]}``).
    """
    on_root = rank == 0
    # Only the destination rank supplies a receive list.
    gathered = [None] * dist.get_world_size() if on_root else None
    dist.gather_object(obj=device_handles,
                       object_gather_list=gathered,
                       dst=0)
    if not on_root:
        return

    # Merge the per-rank dicts; on a key clash the later rank wins.
    merged = {}
    for per_rank in gathered:
        merged.update(per_rank)
    self.llm.update_weights_from_ipc_handles(merged)
310+
def update_weights_from_tensor_generator(self,
                                         tensor_generator,
                                         max_batch_bytes=0.7 * (1024**3)):
    """Stream weights to the LLM in size-bounded batches of IPC handles.

    Each tensor from ``tensor_generator`` is turned into an IPC handle
    with ``reduce_tensor``. Handles accumulate until the next item would
    exceed ``max_batch_bytes``; the batch is then pushed through
    ``self.update_weights_from_ipc_handles`` and a new one is started.
    Parameters whose name contains "gate_proj" / "up_proj" are held back
    until their partner arrives, so both always travel in the same batch.

    Args:
        tensor_generator: Iterable of ``(name, tensor)`` pairs. DTensor
            entries are materialized with ``full_tensor()``.
        max_batch_bytes: Soft cap on the bytes shared per batch. An item
            (or gate/up pair) larger than the cap is sent as a batch of
            its own instead of aborting the refit.

    Raises:
        RuntimeError: If the same gate or up projection is seen twice, or
            if a gate/up projection never receives its partner.
    """
    from torch.multiprocessing.reductions import reduce_tensor

    device_uuid = report_device_id()
    rank = dist.get_rank()

    budget_left = max_batch_bytes
    # Keeps the shared tensors referenced until their batch is consumed.
    held_params = {}
    batch_handles = []
    # Half-filled gate/up pairs, keyed by the name with the marker removed.
    pending_pairs = {}

    for name, param in tensor_generator:
        if isinstance(param, DTensor):
            param = param.full_tensor()
        # Measured on the materialized tensor that is actually shared.
        entry = NamedParam(name, param.element_size() * param.numel(), param)
        to_send = [entry]
        needed = entry.size

        # NOTE(review): plain substring matching -- a fused name such as
        # "gate_up_proj" would also match "up_proj"; confirm no such keys.
        is_gate = "gate_proj" in name
        if is_gate or "up_proj" in name:
            marker = "gate_proj" if is_gate else "up_proj"
            pair_key = name.replace(marker, "")
            pair = pending_pairs.get(pair_key)
            if pair is None:
                pair = pending_pairs[pair_key] = GateAndUp()
            already_set = pair.gate if is_gate else pair.up
            if already_set is not None:
                raise RuntimeError(
                    f"Duplicate {marker} parameter for pair: {name}")
            if is_gate:
                pair.set_gate(entry)
            else:
                pair.set_up(entry)
            if not pair.is_complete():
                # Wait for the partner before sharing anything.
                continue
            del pending_pairs[pair_key]
            to_send = [pair.gate, pair.up]
            needed = pair.get_size()

        # Flush only when something is pending: an empty flush would just
        # run a no-op collective. An oversized item therefore ends up
        # alone in its batch (budget_left goes negative, which forces a
        # flush before the next item).
        if batch_handles and needed > budget_left:
            self.update_weights_from_ipc_handles(
                rank, {device_uuid: batch_handles})
            held_params = {}
            batch_handles = []
            budget_left = max_batch_bytes

        budget_left -= needed
        for item in to_send:
            held_params[item.name] = item.param
            batch_handles.append(
                (item.name, reduce_tensor(item.param.detach())))

    if pending_pairs:
        raise RuntimeError(
            "gate_proj/up_proj parameters without a partner: "
            f"{sorted(pending_pairs)}")

    if batch_handles:
        self.update_weights_from_ipc_handles(
            rank, {device_uuid: batch_handles})
377+
def get_total_available_bytes(pg: dist.ProcessGroup,
                              message: str = "",
                              fraction: float = 0.2) -> int:
    """Return a byte budget derived from the least free GPU memory in ``pg``.

    Reads the local CUDA memory statistics, prints them, reduces the free
    memory with MIN over ``pg`` and returns ``fraction`` of the result.

    Args:
        pg: Process group over which the minimum free memory is taken.
        message: Prefix for the two diagnostic lines that are printed.
        fraction: Share of the minimum free memory to hand out
            (default 0.2, the previously hard-coded factor).

    Returns:
        int: ``fraction`` of the smallest ``mem_free`` in the group,
            truncated to whole bytes to match the declared return type
            (the previous code returned a float).
    """
    mem_allocated = torch.cuda.memory_allocated()
    mem_reserved = torch.cuda.memory_reserved()
    mem_free, mem_total = torch.cuda.mem_get_info()
    print(
        f"{message} mem_free: {mem_free:,}, mem_total: {mem_total:,}, mem_allocated: {mem_allocated:,}, mem_reserved: {mem_reserved:,}"
    )
    # torch.tensor() without a device argument builds a CPU tensor.
    # NOTE(review): this presumably relies on ``pg`` having a CPU-capable
    # backend (e.g. gloo) -- confirm against the process-group setup.
    min_free = torch.tensor(mem_free)
    dist.all_reduce(min_free, op=dist.ReduceOp.MIN, group=pg)
    min_free = min_free.item()
    print(f"{message} gathered_mem_free: {min_free:,}")
    return int(min_free * fraction)
388+
def cleanup():
    """Cleanup function to destroy process group"""
    # Nothing to tear down if no process group was ever initialized.
    if not dist.is_initialized():
        return
    print(f"Cleaning up process group on rank {dist.get_rank()}")
    dist.destroy_process_group()
265394
266-
267395def main ():
268396 parser = argparse .ArgumentParser (
269397 description = "LLM models with the PyTorch workflow." )
@@ -306,7 +434,6 @@ def main():
306434 # For FSDP mode, we would need additional logic to integrate with TensorRT-LLM
307435 # This is a placeholder for now
308436 if rank == 0 :
309-
310437 outputs = trtllm .llm .generate (prompts , sampling_params )
311438 for i , output in enumerate (outputs ):
312439 prompt = output .prompt
@@ -321,33 +448,9 @@ def main():
321448 result = trtllm .llm .wakeup ()
322449 print (f"wakeup result: { result } " )
323450
324- dict_info , total_available_bytes = fsdp .prepare_weights_for_ipc ()
325-
326- grouped_param_keys = fsdp .prepare_weights_for_ipc_refit (0.5 )
327- total_num_keys = sum (len (k ) for k in grouped_param_keys )
328- print (
329- f"[Refit] Split { total_num_keys } keys into { len (grouped_param_keys )} groups"
330- )
331-
332- from tensorrt_llm ._torch .utils import get_free_memory_bytes
333- for keys in grouped_param_keys :
334- handles = fsdp .get_weights_ipc_handles (keys )
335- #print(f"handles: {handles}")
336-
337- # Collect handles from all ranks
338- all_handles = [None for _ in range (world_size )]
339- dist .all_gather_object (all_handles , handles )
340- all_handles = {k : v for d in all_handles for k , v in d .items ()}
341- #print(f"all_handles: {all_handles.keys()}")
342-
343- device_idx = torch .cuda .current_device ()
344- total_available_bytes = get_free_memory_bytes (device_idx )
345- print (f"total_available_bytes: { total_available_bytes } " )
346-
347- if rank == 0 :
348- result = trtllm .llm .update_weights_from_ipc_handles (all_handles )
349- print (f"update weights result: { result } " )
451+ trtllm .update_weights_from_tensor_generator (fsdp .per_tensor_generator ())
350452
453+ # generate the output again
351454 if rank == 0 :
352455 outputs = trtllm .llm .generate (prompts , sampling_params )
353456 for i , output in enumerate (outputs ):
0 commit comments