@@ -149,7 +149,11 @@ def do_fsdp(model):
         self.reference_model = do_fsdp(self.reference_model)
         self.reference_model = self.manual_offload_to_cpu(self.reference_model)
         self.model = self.manual_load_to_gpu(self.model)
-        self._held_reference_model_params = None
+
+        # used for streaming updates of the inference engine weights
+        self._held_sharded_state_dict_reference = None
+        self._held_streamed_param_reference = None
+
         # register_fsdp_forward_method(self.model, "generate")
         if init_optimizer:
             optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"])
@@ -205,6 +209,9 @@ def do_fsdp(model):
     def is_alive(self):
         return True
 
+    def reset_peak_memory_stats(self):
+        torch.cuda.reset_peak_memory_stats()
+
     def get_gpu_info(self):
         """Return information about the GPU being used by this worker."""
         return get_gpu_info(self.model)
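
For reference, a minimal sketch of how a caller might pair the new
reset_peak_memory_stats() hook with torch.cuda.max_memory_allocated() to
measure a worker's GPU-memory high-water mark. The measure_peak_memory
helper and the callable passed into it are illustrative assumptions, not
part of this commit:

import torch

def measure_peak_memory(worker, fn):
    # `worker` is assumed to expose the reset_peak_memory_stats() method
    # added above; `fn` is any callable to observe (e.g. one refit step).
    worker.reset_peak_memory_stats()
    fn()
    # Returns the high-water mark (in bytes) since the last reset.
    return torch.cuda.max_memory_allocated()
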
@@ -720,38 +727,61 @@ def report_device_id(self) -> str:
         return get_device_uuid(device_idx)
 
     @torch.no_grad()
-    def get_weight_ipc_handles(self, offload_model=True):
-        from torch.multiprocessing.reductions import reduce_tensor
+    def prepare_weights_for_ipc(self):
+        from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
 
         # If the model is not FSDP, then we need to manually move it to the GPU
         # For an FSDP model, model.state_dict() will move the params to the GPU
-        if not isinstance(self.model, torch.distributed.fsdp.FullyShardedDataParallel):
+        if not isinstance(self.model, FullyShardedDataParallel):
             self.model = self.manual_load_to_gpu(self.model)
+            self._held_sharded_state_dict_reference = self.model.state_dict()
+        else:
+            # Get a sharded state dict instead of a full state dict for FSDP1
+            with FullyShardedDataParallel.state_dict_type(
+                self.model,
+                state_dict_type=StateDictType.SHARDED_STATE_DICT,
+                state_dict_config=ShardedStateDictConfig(),
+            ):
+                self._held_sharded_state_dict_reference = self.model.state_dict()
+
+        # Collect info for streaming multiple tensors
+        state_dict_info = []
+        for name, tensor in self._held_sharded_state_dict_reference.items():
+            # A DTensor's numel() counts the complete tensor, not just the local shard
+            size_in_bytes = tensor.element_size() * tensor.numel()
+            state_dict_info.append((name, size_in_bytes))
+
+        return state_dict_info
 
-        # TODO @sahilj: do this without an allgather (maybe FSDP2)
-        params = self.model.state_dict()
+    @torch.no_grad()
+    def get_weights_ipc_handles(self, keys):
+        from torch.distributed.tensor import DTensor
+        from torch.multiprocessing.reductions import reduce_tensor
-        # Create a copy of parameters in the desired dtype (bfloat16 or float32)
-        dtype_params = {}
-        for name, param in params.items():
+        converted_params = {}
+        for key in keys:
+            # Gather the full tensor from a DTensor (sharded when using >1 GPU)
+            tensor = self._held_sharded_state_dict_reference[key]
+            if isinstance(tensor, DTensor):
+                full_tensor = tensor.full_tensor()
+            else:
+                full_tensor = tensor
             # Convert parameters to the configured dtype
-            dtype_params[name] = param.to(self.dtype, non_blocking=True)
-
-        # Replace the original params with the converted ones
-        params = dtype_params
-        # For FSDP1, params may get GC'ed before sending to vllm,
-        # so we need to hold a reference to them
-        self._held_reference_model_params = params
-        data = {}
+            converted_params[key] = full_tensor.to(self.dtype, non_blocking=True)
+
+        # Temporarily hold a reference to the full tensors; this is needed
+        # so the last full_tensor can be cleaned up after the refit process
+        self._held_streamed_param_reference = converted_params
+
+        # Get device UUID for IPC
         device_uuid = self.report_device_id()
-        for name, p in params.items():
-            data[name] = reduce_tensor(p.detach())
+        # Create handles for the tensors
+        all_handles = []
+        for key, p in converted_params.items():
+            handle = reduce_tensor(p.detach())
+            all_handles.append((key, handle))
 
-        if offload_model:
-            self.model = self.manual_offload_to_cpu(self.model)
-        gc.collect()
-        torch.cuda.empty_cache()
-        return {device_uuid: data}
+        return {device_uuid: all_handles}
 
     def prepare_for_lp_inference(self):
         self.model = self.manual_load_to_gpu(self.model)
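
Taken together, prepare_weights_for_ipc and get_weights_ipc_handles let a
refit driver stream weights in size-bounded groups rather than gathering
every full tensor at once. A minimal sketch of such a driver, assuming a
hypothetical per-group byte budget (stream_weights and budget_bytes are
illustrative, not part of this commit):

def stream_weights(worker, budget_bytes=2 * 2**30):
    # Pack parameter names into groups whose summed size stays under
    # `budget_bytes`, using the (name, size_in_bytes) pairs reported
    # by prepare_weights_for_ipc().
    state_dict_info = worker.prepare_weights_for_ipc()
    groups, current, current_size = [], [], 0
    for name, size in state_dict_info:
        if current and current_size + size > budget_bytes:
            groups.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        groups.append(current)

    # Fetch IPC handles one group at a time. Each call replaces
    # worker._held_streamed_param_reference, so at most one group of
    # gathered full tensors is alive on the GPU at any moment.
    for keys in groups:
        yield worker.get_weights_ipc_handles(keys)
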
@@ -802,9 +832,13 @@ def offload_after_refit(self):
         torch.randn(1).cuda()  # wake up torch allocator
         self.offload_before_refit()  # rerun the old offload function
 
-        if self._held_reference_model_params is not None:
-            del self._held_reference_model_params
-            self._held_reference_model_params = None
+        # Clean up the held tensors
+        if self._held_sharded_state_dict_reference is not None:
+            del self._held_sharded_state_dict_reference
+            self._held_sharded_state_dict_reference = None
+        if self._held_streamed_param_reference is not None:
+            del self._held_streamed_param_reference
+            self._held_streamed_param_reference = None
 
         gc.collect()
         torch.cuda.empty_cache()
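
On the consuming side, each handle produced by reduce_tensor() is a
(rebuild_fn, args) pair, so a process with access to the same device can
rematerialize the tensor over CUDA IPC without copying it. A minimal
sketch, assuming the {device_uuid: [(name, handle), ...]} mapping returned
by get_weights_ipc_handles (the rebuild_from_handles helper is
illustrative):

def rebuild_from_handles(handles_by_device, device_uuid):
    # `handles_by_device` is the mapping returned by
    # get_weights_ipc_handles(); each handle is the (rebuild_fn, args)
    # pair that reduce_tensor() produced in the sending process.
    for name, (rebuild_fn, args) in handles_by_device[device_uuid]:
        # rebuild_fn(*args) maps the sender's CUDA allocation into this
        # process via CUDA IPC; no device-to-device copy is made.
        tensor = rebuild_fn(*args)
        yield name, tensor  # e.g. copy_() into the inference engine weights
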