
Commit baf6f61

fix: use tcp store_based_barrier to control p2p update synchronization (#51)
* fix: store based barrier for all processes' synchronization
* refactor: rewrite update process management logic
* misc: fix pr issues
* fix: merge issues fixed
* misc
1 parent 089d185 · commit baf6f61
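The fix relies on a barrier that rendezvouses through a bare dist.TCPStore rather than through a process group, so every rank can synchronize even when only a subset of ranks joins the update group. The commit delegates this to PyTorch's private dist.distributed_c10d._store_based_barrier helper; the sketch below only illustrates the underlying idea with a hand-rolled counter, and the key name and polling interval are assumptions rather than code from this commit.

import time
from datetime import timedelta

import torch.distributed as dist


def store_barrier(
    store: dist.TCPStore,
    rank: int,
    world_size: int,
    timeout: timedelta = timedelta(minutes=5),
) -> None:
    # Each rank bumps a shared counter in the TCP store, then polls until every
    # rank has checked in. No process group is involved, so ranks that are not
    # part of the update sub-group can still rendezvous here.
    key = "parameter_server_barrier/arrived"  # hypothetical key name
    store.add(key, 1)
    deadline = time.monotonic() + timeout.total_seconds()
    while store.add(key, 0) < world_size:  # add(key, 0) just reads the counter
        if time.monotonic() > deadline:
            raise TimeoutError(f"rank {rank} timed out waiting for barrier")
        time.sleep(0.01)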

File tree: 2 files changed, +70 -86 lines

checkpoint_engine/ps.py

Lines changed: 63 additions & 85 deletions
@@ -786,20 +786,6 @@ def _get_master_port(master_port: int | None = None) -> int:
     return master_port
 
 
-def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
-    """
-    map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
-    which are generated in self.init_process_group_for_ranks
-    """
-    bcast_rank_map: dict[int, int] = {}
-    if not ranks:
-        bcast_rank_map = {r: r for r in range(world_size)}
-    else:
-        for i, r in enumerate(ranks):
-            bcast_rank_map[r] = i
-    return bcast_rank_map
-
-
 class P2PStore:
     def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine
@@ -1164,12 +1150,36 @@ def init_process_group(
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")
 
+    def store_based_barrier(
+        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
+    ) -> None:
+        """
+        Perform a store-based barrier synchronization across all ranks.
+
+        This barrier uses a TCP store directly rather than a process group,
+        allowing all ranks to synchronize regardless of which process group
+        they belong to.
+
+        Args:
+            store: The TCPStore instance to use for synchronization.
+        """
+        dist.distributed_c10d._store_based_barrier(
+            rank=self._rank,
+            store=store,
+            group_name="parameter_server_barrier",
+            rendezvous_count=self._world_size,
+            timeout=timeout,
+        )
+
     def update(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
         *,
+        timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -1181,34 +1191,45 @@ def update(
                 which is the fastest way to update weights, especially in colocated architecture.
                 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
                 which is useful in disaggregated architecture.
+            master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
+            master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
+            timeout: The timeout of the barrier operation.
         """
         assert req_func is not None, "req_func is required"
+        ranks_group = None
         try:
-            # if both ranks is None or [], it will use fully broadcast to update to all ranks
-            if not ranks:
-                if self._auto_pg and not dist.is_initialized():
-                    self.init_process_group()
-                self._update_per_bucket(checkpoint_name, req_func)
+            master_addr = os.getenv("MASTER_ADDR") or master_addr
+            assert master_addr, "master_addr is required"
+            if self._auto_pg:
+                if not dist.is_initialized():
+                    self.init_process_group(
+                        timeout=timeout, master_addr=master_addr, master_port=master_port
+                    )
+                manager_store = dist.distributed_c10d._get_default_store()
             else:
-                if self._auto_pg:
-                    if dist.is_initialized():
-                        dist.destroy_process_group()
-                        # HACK: wait 2s to ensure destroy is finished
-                        time.sleep(2)
-                    self.init_process_group_for_ranks(ranks)
-                if self._rank not in ranks:
-                    return
-                self._update_per_bucket(checkpoint_name, req_func, ranks)
-
+                # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
+                # If master_port is provided, use master_port+1 for barrier store
+                manager_store = dist.TCPStore(
+                    master_addr,
+                    _get_master_port(master_port) + 1,
+                    self._world_size,
+                    timeout=timeout,
+                    is_master=self._rank == 0,
+                )
+            # if ranks is None or [], it will use fully broadcast to update to all ranks
+            ranks_group = dist.new_group(ranks if ranks else None)
+            self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
+            self.store_based_barrier(manager_store)
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
         finally:
-            if self._auto_pg and (not ranks or self._rank in ranks):
+            if ranks_group:
+                dist.destroy_process_group(ranks_group)
+            if self._auto_pg and dist.is_initialized():
                 dist.destroy_process_group()
-
             self.device_manager.device_module.empty_cache()
             logger.info(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
@@ -1226,7 +1247,9 @@ def zmq_handle(device_uuid: str) -> str:
             self._zmq_addr_counter += 1
         return socket, socket_paths
 
-    def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
+    def _detect_bucket_size(
+        self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
+    ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
         tensor = torch.tensor(
@@ -1242,7 +1265,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
             dtype=torch.int64,
             device=self.device_manager.device_type,
         )
-        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
         tensor = tensor.cpu()
         free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
         max_tensor_bytes = 0
@@ -1305,51 +1328,6 @@ def _copy_to_buffer(
         self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
         self.device_manager.device_module.synchronize()
 
-    def init_process_group_for_ranks(
-        self,
-        ranks: list[int],
-        *,
-        master_port: int | None = None,
-        timeout: timedelta = timedelta(minutes=10),
-    ):
-        """
-        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
-
-        Args:
-            ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
-            master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
-            timeout: The timeout of the process group.
-        """
-        assert not dist.is_initialized()
-        assert ranks, "ranks should be set"
-        if self._rank not in ranks:
-            return
-        assert self._all_hosts, "all_hosts should be set"
-        assert len(self._all_hosts) == self._world_size // self._gpu_count, (
-            f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
-        )
-        rank = ranks.index(self._rank)
-        master_addr = self._all_hosts[ranks[0] // self._gpu_count]
-        master_port = _get_master_port(master_port)
-        logger.info(
-            f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
-            f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
-        )
-        # only initialize process group and store for ranks, other nodes are not initialized
-        # and will not participate in this update. Since they have registered memory addresses
-        # to p2p_store at the beginning, update ranks can directly get the memory addresses
-        # from other nodes and put the weights into the buffer.
-        store = dist.TCPStore(
-            master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
-        )
-        dist.init_process_group(
-            backend=self.device_manager.backend,
-            world_size=len(ranks),
-            rank=rank,
-            timeout=timeout,
-            store=store,
-        )
-
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
         metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1389,10 +1367,12 @@ def _update_per_bucket(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks_group: dist.ProcessGroup,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
         assert dist.is_initialized(), "process group is not initialized"
+
         # if both ranks is None or [], it will use fully broadcast to update to all ranks
         if not ranks:
             logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -1410,9 +1390,9 @@ def _update_per_bucket(
         if not need_update:
             return
         # first execute a barrier to avoid subsequent device oom
-        dist.barrier()
+        dist.barrier(group=ranks_group)
 
-        bucket_size, disable_h2d_buffer = self._detect_bucket_size()
+        bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
         buckets = _gen_h2d_buckets(
             self._current_global_parameter_metas,
             bucket_size,
@@ -1459,7 +1439,6 @@ def _update_per_bucket(
 
         gidx = 0
         ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
-        bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
         try:
             for i in range(max_len):
                 if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
@@ -1489,16 +1468,15 @@ def _update_per_bucket(
                         self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                     else:
                         buffer_b.data.copy_(h2d_buffer[: bucket.size])
-                brank = bcast_rank_map[receiver_rank]
-                dist.broadcast(buffer_b, src=brank)
+                dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
                 resp = socket.recv()
                 if resp != b"":
                     msg = resp.decode("utf-8")
                     logger.error(
                         f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                     )
                     ret_code.fill_(1)
-                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
                 self.device_manager.device_module.synchronize()
                 if ret_code.item() != 0:
                     # quit early if any rank failed
@@ -1512,7 +1490,7 @@ def _update_per_bucket(
                 socket.recv()
         finally:
             req_thread.join()
-            dist.barrier()
+            dist.barrier(group=ranks_group)
             socket.close()
             if ranks and h2d_buffer is not None:
                 self._p2p_store.unregister_named_tensors([h2d_buffer_name])
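Taken together, update() now always builds a (possibly whole-world) sub-group with dist.new_group, runs the collectives against that group, and then joins a store-based barrier sized to the full world, so ranks outside the given ranks list no longer return early. A rough caller-side sketch against the new signature; the ParameterServer instance, checkpoint names, callback, and address below are placeholders, not taken from this commit.

from datetime import timedelta


def push_weights(ps, req_func) -> None:
    # Broadcast path: with ranks unset, every rank receives the checkpoint.
    ps.update("ckpt-step-100", req_func, timeout=timedelta(minutes=10))

    # P2P path: only ranks 0 and 1 receive this checkpoint. All other ranks
    # still call update() and wait on the TCP-store barrier instead of
    # returning early as in the old init_process_group_for_ranks flow.
    ps.update(
        "ckpt-step-200",
        req_func,
        ranks=[0, 1],
        master_addr="10.0.0.1",  # placeholder; env MASTER_ADDR is also consulted
        master_port=29500,       # placeholder; see the MASTER_PORT+1 note in the docstring
    )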

tests/test_update.py

Lines changed: 7 additions & 1 deletion
@@ -237,7 +237,13 @@ def run_with_files(
             ],
         ),
         ("test_with_remote_error", [[]]),
-        # ("long_test_no_error", [list(random.sample(range(get_world_size()), k=num_ranks)) for num_ranks in range(get_world_size() + 1)]),
+        (
+            "test_no_error",
+            [
+                list(random.sample(range(get_world_size()), k=num_ranks))
+                for num_ranks in range(get_world_size() + 1)
+            ],
+        ),
     ],
 )
 def test_update(test_name: str, rank_list: list[list[int]] | None):
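The re-enabled test_no_error case sweeps one random rank subset per size, from the empty list (the full-broadcast path) up to all ranks. A small sketch of what the comprehension expands to, assuming a hypothetical world size of 4:

import random


def sample_rank_lists(world_size: int) -> list[list[int]]:
    # One random subset per size 0..world_size, mirroring the parametrization above.
    return [
        list(random.sample(range(world_size), k=num_ranks))
        for num_ranks in range(world_size + 1)
    ]


# sample_rank_lists(4) might produce, for example:
# [[], [2], [0, 3], [3, 1, 0], [2, 0, 3, 1]]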
