Commit fc540ce

tushar00jain authored and pytorchmergebot committed
set pg name based on ranks (pytorch#166182)
Summary:
- In torchft we have multiple default pgs, one for each task group.
- For flight recorder to work, each of these needs a different name so that entries can be matched.
- Change the `init_process_group` API to optionally take a list of ranks. If provided, the hash of the ranks is used as the name of the pg. For torchft, we'll pass global ranks here so the default pg has a different name on each task group.

Pull Request resolved: pytorch#166182
Approved by: https://github.com/fduwjj
1 parent d1a6e00 commit fc540ce
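For context, here is a minimal sketch of how a torchft-style caller might use the new parameter. The backend, ranks, and rendezvous settings below are illustrative assumptions, not part of this commit; `_ranks` is private API and running this requires an actual rendezvous and a build that includes this change.

```python
# Illustrative sketch (not from this commit): giving each task group's
# default pg a distinct name by passing its global ranks to the new
# private _ranks parameter. Backend/env settings here are assumptions.
import torch.distributed as dist

# Global ranks of this task group (illustrative values).
task_group_ranks = [0, 1, 2, 3]

dist.init_process_group(
    backend="gloo",           # illustrative backend choice
    init_method="env://",     # assumes MASTER_ADDR/MASTER_PORT are set
    rank=0,
    world_size=4,
    _ranks=task_group_ranks,  # pg name becomes a hash of these ranks
)
```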


torch/distributed/distributed_c10d.py

Lines changed: 7 additions & 1 deletion
@@ -1583,6 +1583,7 @@ def init_process_group(
     group_name: str = "",
     pg_options: Optional[Any] = None,
     device_id: Optional[Union[torch.device, int]] = None,
+    _ranks: Optional[list[int]] = None,
 ) -> None:
     """
     Initialize the default distributed process group.
@@ -1657,6 +1658,8 @@ def init_process_group(
             want to know NCCL initialization error early, you can also use this
             field. If an `int` is provided, the API assumes that the accelerator
             type at compile time will be used.
+        _ranks: The ranks in the process group. If provided, the process
+            group name will be the hash of all the ranks in the group.

     .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
         on a system that supports MPI.
@@ -1761,7 +1764,10 @@ def init_process_group(
     internals of c10d. This means we can ignore the value
     they provide as it not exposed in a public way.
     """
-    group_name = _process_group_name([], use_hashed_name=False)
+    if _ranks is None or len(_ranks) == 0:
+        group_name = _process_group_name([], use_hashed_name=False)
+    else:
+        group_name = _process_group_name(_ranks, use_hashed_name=True)
     if backend == Backend.MPI:
         if world_size != -1 or rank != -1:
             warnings.warn(
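To illustrate why hashing the ranks gives each task group a distinct but deterministic default-pg name, here is a standalone sketch. It is not PyTorch's actual `_process_group_name` implementation, just an analogous hash; the function name and SHA-1 choice are assumptions for illustration.

```python
# Standalone illustration only: PyTorch's _process_group_name may hash
# differently. The point is that different rank lists produce different,
# deterministic names, which is what flight recorder matching needs.
import hashlib

def illustrative_pg_name(ranks: list[int]) -> str:
    payload = ",".join(str(r) for r in ranks)
    return hashlib.sha1(payload.encode()).hexdigest()

print(illustrative_pg_name([0, 1, 2, 3]))  # task group A
print(illustrative_pg_name([4, 5, 6, 7]))  # task group B: different name
```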
