@@ -459,17 +459,17 @@ def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
459459 logits .div_ (self .cfg ["generation" ]["temperature" ])
460460 return logits
461461
def init_collective(
    self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> None:
    """Join the collective group used for model weight updates.

    Every worker calls this with its own ``self.rank`` (not only rank 0):
    it creates a ``StatelessProcessGroup`` rendezvous at ``ip:port`` and
    wraps it in a ``PyNcclCommunicator`` bound to the current CUDA device.
    The communicator is stored on ``self.model_update_group``.

    Args:
        ip: Host passed to the process-group rendezvous.
        port: Port passed to the process-group rendezvous.
        world_size: Total number of ranks in the collective group.
        train_world_size: Keyword-only; not read by this method.
            NOTE(review): presumably kept so this worker shares a call
            signature with other workers that need it -- confirm against
            the caller.

    Raises:
        ValueError: If ``self.rank`` is outside ``[0, world_size)``.
    """
    # Validate before the rendezvous so a misconfigured rank surfaces as a
    # clear error instead of reaching process-group creation.
    if not 0 <= self.rank < world_size:
        raise ValueError(
            f"rank {self.rank} is out of range for world_size {world_size}"
        )

    # Imported lazily so vLLM is only required when the collective is used.
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    pg = StatelessProcessGroup.create(
        host=ip, port=port, rank=self.rank, world_size=world_size
    )
    device = torch.cuda.current_device()
    self.model_update_group = PyNcclCommunicator(pg, device=device)
473473
474474 def is_alive (self ) -> bool :
475475 return True
@@ -1770,9 +1770,8 @@ def broadcast_weights_for_collective(self) -> None:
17701770 for _ , tensor in self .model .state_dict ().items ():
17711771 if isinstance (tensor , DTensor ):
17721772 tensor = tensor .full_tensor ()
1773- if self .rank == 0 :
1774- tensor = tensor .to (self .dtype , non_blocking = True )
1775- self .model_update_group .broadcast (tensor .data , src = 0 )
1773+ tensor = tensor .to (self .dtype , non_blocking = True )
1774+ self .model_update_group .broadcast (tensor .data , src = 0 )
17761775
17771776 # Manually move model to cpu for cpu offload case
17781777 # cpu offload needs model on CPU before model forward
0 commit comments