@@ -786,20 +786,6 @@ def _get_master_port(master_port: int | None = None) -> int:
     return master_port


-def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
-    """
-    map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
-    which are generated in self.init_process_group_for_ranks
-    """
-    bcast_rank_map: dict[int, int] = {}
-    if not ranks:
-        bcast_rank_map = {r: r for r in range(world_size)}
-    else:
-        for i, r in enumerate(ranks):
-            bcast_rank_map[r] = i
-    return bcast_rank_map
-
-
 class P2PStore:
     def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine
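
Note: the removed `_get_bcast_rank_map` only translated global ranks into the 0-based numbering of the old per-subset process group. The later hunks pass the sub-group to collectives via `group=` and keep global rank numbering, so the translation is no longer needed. For reference, a sketch of what the helper computed (world size and ranks are illustrative):

    # Reference only: behavior of the removed helper with illustrative inputs.
    full_map = {r: r for r in range(4)}                # ranks=None  -> identity {0: 0, 1: 1, 2: 2, 3: 3}
    subset_map = {r: i for i, r in enumerate([1, 3])}  # ranks=[1, 3] -> {1: 0, 3: 1}
    assert subset_map[3] == 1  # global rank 3 was broadcast as sub-group rank 1
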
@@ -1164,12 +1150,36 @@ def init_process_group(
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")

+    def store_based_barrier(
+        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
+    ) -> None:
+        """
+        Perform a store-based barrier synchronization across all ranks.
+
+        This barrier uses a TCP store directly rather than a process group,
+        allowing all ranks to synchronize regardless of which process group
+        they belong to.
+
+        Args:
+            store: The TCPStore instance to use for synchronization.
+        """
+        dist.distributed_c10d._store_based_barrier(
+            rank=self._rank,
+            store=store,
+            group_name="parameter_server_barrier",
+            rendezvous_count=self._world_size,
+            timeout=timeout,
+        )
+
     def update(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
         *,
+        timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -1181,34 +1191,45 @@ def update(
                 which is the fastest way to update weights, especially in colocated architecture.
                 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
                 which is useful in disaggregated architecture.
+            master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
+            master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which returns MASTER_PORT+1.
+            timeout: The timeout of the barrier operation.
         """
         assert req_func is not None, "req_func is required"
+        ranks_group = None
         try:
-            # if both ranks is None or [], it will use fully broadcast to update to all ranks
-            if not ranks:
-                if self._auto_pg and not dist.is_initialized():
-                    self.init_process_group()
-                self._update_per_bucket(checkpoint_name, req_func)
+            master_addr = os.getenv("MASTER_ADDR") or master_addr
+            assert master_addr, "master_addr is required"
+            if self._auto_pg:
+                if not dist.is_initialized():
+                    self.init_process_group(
+                        timeout=timeout, master_addr=master_addr, master_port=master_port
+                    )
+                manager_store = dist.distributed_c10d._get_default_store()
             else:
-                if self._auto_pg:
-                    if dist.is_initialized():
-                        dist.destroy_process_group()
-                        # HACK: wait 2s to ensure destroy is finished
-                        time.sleep(2)
-                    self.init_process_group_for_ranks(ranks)
-                if self._rank not in ranks:
-                    return
-                self._update_per_bucket(checkpoint_name, req_func, ranks)
-
+                # HACK: use MASTER_PORT+2 for the barrier store when master_port is not provided (_get_master_port() returns MASTER_PORT+1).
+                # If master_port is provided, use master_port+1 for the barrier store.
+                manager_store = dist.TCPStore(
+                    master_addr,
+                    _get_master_port(master_port) + 1,
+                    self._world_size,
+                    timeout=timeout,
+                    is_master=self._rank == 0,
+                )
+            # if ranks is None or [], it will use fully broadcast to update to all ranks
+            ranks_group = dist.new_group(ranks if ranks else None)
+            self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
+            self.store_based_barrier(manager_store)
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
         finally:
-            if self._auto_pg and (not ranks or self._rank in ranks):
+            if ranks_group:
+                dist.destroy_process_group(ranks_group)
+            if self._auto_pg and dist.is_initialized():
                 dist.destroy_process_group()
-
             self.device_manager.device_module.empty_cache()
             logger.info(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
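
Note: a usage sketch of the reworked update() path. The names `ps` and `notify_engine` are illustrative (the class holding update() and the caller's callback are defined elsewhere); the keyword arguments follow the new signature above.

    from datetime import timedelta

    def notify_engine(named_handles: list[tuple[str, str]]) -> None:
        ...  # forward (name, handle) pairs to the inference engine

    # Broadcast to all ranks; assumes MASTER_ADDR is set in the environment.
    ps.update("ckpt-step-100", notify_engine, timeout=timedelta(minutes=10))

    # Update restricted to a sub-group of ranks. Every rank still calls update()
    # so that dist.new_group() and the store-based barrier see the full world.
    ps.update(
        "ckpt-step-100",
        notify_engine,
        ranks=[0, 1, 2, 3],
        master_addr="10.0.0.1",  # only used when env MASTER_ADDR is unset
        master_port=29500,       # the barrier store binds to master_port + 1
    )
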
@@ -1226,7 +1247,9 @@ def zmq_handle(device_uuid: str) -> str:
             self._zmq_addr_counter += 1
         return socket, socket_paths

-    def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
+    def _detect_bucket_size(
+        self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
+    ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
         tensor = torch.tensor(
@@ -1242,7 +1265,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
             dtype=torch.int64,
             device=self.device_manager.device_type,
         )
-        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
         tensor = tensor.cpu()
         free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
         max_tensor_bytes = 0
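
Note: the single MIN all-reduce above agrees on two values at once: slot 0 carries each rank's free device memory (minimized directly), and slot 1 carries the negated zmq address counter, so minimizing the negative yields the maximum counter across the group. A standalone sketch of the same trick, assuming torch.distributed is already initialized:

    import torch
    import torch.distributed as dist

    ranks_group = dist.new_group(None)    # or a subset, as created in update()
    free_bytes = 7 * (1 << 30)            # illustrative per-rank free memory
    zmq_addr_counter = 3                  # illustrative per-rank counter
    t = torch.tensor(
        [free_bytes, -zmq_addr_counter],
        dtype=torch.int64,
        device="cuda",                    # self.device_manager.device_type in the real code
    )
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=ranks_group)
    min_free_bytes = t[0].item()          # smallest free memory across ranks
    max_counter = -t[1].item()            # MIN of negatives == MAX of the counter
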
@@ -1305,51 +1328,6 @@ def _copy_to_buffer(
         self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
         self.device_manager.device_module.synchronize()

-    def init_process_group_for_ranks(
-        self,
-        ranks: list[int],
-        *,
-        master_port: int | None = None,
-        timeout: timedelta = timedelta(minutes=10),
-    ):
-        """
-        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
-
-        Args:
-            ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
-            master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
-            timeout: The timeout of the process group.
-        """
-        assert not dist.is_initialized()
-        assert ranks, "ranks should be set"
-        if self._rank not in ranks:
-            return
-        assert self._all_hosts, "all_hosts should be set"
-        assert len(self._all_hosts) == self._world_size // self._gpu_count, (
-            f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
-        )
-        rank = ranks.index(self._rank)
-        master_addr = self._all_hosts[ranks[0] // self._gpu_count]
-        master_port = _get_master_port(master_port)
-        logger.info(
-            f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
-            f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
-        )
-        # only initialize process group and store for ranks, other nodes are not initialized
-        # and will not participate in this update. Since they have registered memory addresses
-        # to p2p_store at the beginning, update ranks can directly get the memory addresses
-        # from other nodes and put the weights into the buffer.
-        store = dist.TCPStore(
-            master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
-        )
-        dist.init_process_group(
-            backend=self.device_manager.backend,
-            world_size=len(ranks),
-            rank=rank,
-            timeout=timeout,
-            store=store,
-        )
-
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
         metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
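
Note: with init_process_group_for_ranks removed, the per-subset update no longer spins up a standalone process group with its own TCPStore and 0-based ranks; update() now derives a sub-group from the existing global group. A minimal sketch of the replacement pattern, assuming the default group is already initialized:

    import torch.distributed as dist

    ranks = [0, 1, 2, 3]                               # illustrative subset
    # new_group() is collective over the default group: every rank must call it,
    # including ranks that are not members of `ranks`.
    ranks_group = dist.new_group(ranks if ranks else None)
    try:
        if dist.get_rank() in ranks:
            dist.barrier(group=ranks_group)            # members run collectives with group=ranks_group
    finally:
        if ranks_group:
            dist.destroy_process_group(ranks_group)    # tear down only the sub-group
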
@@ -1389,10 +1367,12 @@ def _update_per_bucket(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks_group: dist.ProcessGroup,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
         assert dist.is_initialized(), "process group is not initialized"
+
         # if both ranks is None or [], it will use fully broadcast to update to all ranks
         if not ranks:
             logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -1410,9 +1390,9 @@ def _update_per_bucket(
         if not need_update:
             return
         # first execute a barrier to avoid subsequent device oom
-        dist.barrier()
+        dist.barrier(group=ranks_group)

-        bucket_size, disable_h2d_buffer = self._detect_bucket_size()
+        bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
         buckets = _gen_h2d_buckets(
             self._current_global_parameter_metas,
             bucket_size,
@@ -1459,7 +1439,6 @@ def _update_per_bucket(
14591439
14601440 gidx = 0
14611441 ret_code = torch .zeros ((), device = self .device_manager .device_type , dtype = torch .int64 )
1462- bcast_rank_map = _get_bcast_rank_map (self ._world_size , ranks )
14631442 try :
14641443 for i in range (max_len ):
14651444 if i < len (receiver_rank_buckets ) and not disable_h2d_buffer :
@@ -1489,16 +1468,15 @@ def _update_per_bucket(
                     self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                 else:
                     buffer_b.data.copy_(h2d_buffer[: bucket.size])
-                brank = bcast_rank_map[receiver_rank]
-                dist.broadcast(buffer_b, src=brank)
+                dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
                 resp = socket.recv()
                 if resp != b"":
                     msg = resp.decode("utf-8")
                     logger.error(
                         f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                     )
                     ret_code.fill_(1)
-                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
                 self.device_manager.device_module.synchronize()
                 if ret_code.item() != 0:
                     # quit early if any rank failed
@@ -1512,7 +1490,7 @@ def _update_per_bucket(
             socket.recv()
         finally:
             req_thread.join()
-            dist.barrier()
+            dist.barrier(group=ranks_group)
             socket.close()
             if ranks and h2d_buffer is not None:
                 self._p2p_store.unregister_named_tensors([h2d_buffer_name])
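
Note on the broadcast change above: torch.distributed collectives interpret src as a global rank even when group= is supplied, so the receiver's global rank is passed directly; the old bcast_rank_map translation was only needed because the previous per-subset group renumbered ranks from 0. Sketch (ranks and sizes illustrative, default group assumed initialized):

    import torch
    import torch.distributed as dist

    group = dist.new_group([4, 5, 6, 7])       # sub-group of global ranks
    buf = torch.empty(1 << 20, device="cuda")  # buffer participating in the collective
    dist.broadcast(buf, src=6, group=group)    # src is the *global* rank, not index 2 within the group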