
Commit 5997354

xu-song authored and pytorchmergebot committed
Add more distributed examples (pytorch#130427)
1. Add `gather` example
2. Add device to `scatter` example

Pull Request resolved: pytorch#130427
Approved by: https://github.com/kwen2501
1 parent df1eef9 commit 5997354

File tree

1 file changed: +25 -5 lines changed


torch/distributed/distributed_c10d.py

Lines changed: 25 additions & 5 deletions
@@ -3598,6 +3598,24 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group

+    .. note:: Note that all Tensors in gather_list must have the same size.
+
+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # We have 2 process groups, 2 ranks.
+        >>> tensor_size = 2
+        >>> device = torch.device(f'cuda:{rank}')
+        >>> tensor = torch.ones(tensor_size, device=device) + rank
+        >>> if dist.get_rank() == 0:
+        >>>     gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)]
+        >>> else:
+        >>>     gather_list = None
+        >>> dist.gather(tensor, gather_list, dst=0)
+        >>> # Rank 0 gets gathered data.
+        >>> gather_list
+        [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0
+        None                                                                   # Rank 1
+
     """
     _check_single_tensor(tensor, "tensor")
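The new gather doctest assumes an already-initialized process group and an in-scope `rank` variable. For context, here is a minimal, self-contained sketch of how the example runs end to end, assuming a 2-process launch via `torchrun --nproc_per_node=2 gather_demo.py` (the file name is hypothetical) and substituting the gloo backend on CPU for the CUDA devices used in the doctest:

import torch
import torch.distributed as dist

def main():
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT,
    # so the default env:// init method needs no extra arguments.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    tensor_size = 2
    tensor = torch.ones(tensor_size) + rank  # rank 0 -> [1., 1.], rank 1 -> [2., 2.]

    # Only the destination rank allocates the output buffers.
    if rank == 0:
        gather_list = [torch.zeros(tensor_size) for _ in range(world_size)]
    else:
        gather_list = None

    dist.gather(tensor, gather_list, dst=0)
    if rank == 0:
        print(gather_list)  # [tensor([1., 1.]), tensor([2., 2.])]

    dist.destroy_process_group()

if __name__ == "__main__":
    main()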

@@ -3665,19 +3683,21 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
         >>> # Note: Process group initialization omitted on each rank.
         >>> import torch.distributed as dist
         >>> tensor_size = 2
-        >>> t_ones = torch.ones(tensor_size)
-        >>> t_fives = torch.ones(tensor_size) * 5
-        >>> output_tensor = torch.zeros(tensor_size)
+        >>> device = torch.device(f'cuda:{rank}')
+        >>> output_tensor = torch.zeros(tensor_size, device=device)
         >>> if dist.get_rank() == 0:
         >>>     # Assumes world_size of 2.
         >>>     # Only tensors, all of which must be the same size.
+        >>>     t_ones = torch.ones(tensor_size, device=device)
+        >>>     t_fives = torch.ones(tensor_size, device=device) * 5
         >>>     scatter_list = [t_ones, t_fives]
         >>> else:
         >>>     scatter_list = None
         >>> dist.scatter(output_tensor, scatter_list, src=0)
-        >>> # Rank i gets scatter_list[i]. For example, on rank 1:
+        >>> # Rank i gets scatter_list[i].
         >>> output_tensor
-        tensor([5., 5.])
+        tensor([1., 1.], device='cuda:0') # Rank 0
+        tensor([5., 5.], device='cuda:1') # Rank 1

     """
     _check_single_tensor(tensor, "tensor")
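Likewise, a minimal runnable sketch of the updated scatter example, under the same assumptions (2 processes via `torchrun --nproc_per_node=2 scatter_demo.py`, gloo backend on CPU, hypothetical file name):

import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    tensor_size = 2
    output_tensor = torch.zeros(tensor_size)

    if rank == 0:
        # Only the source rank builds scatter_list; all tensors in it
        # must have the same size as output_tensor.
        scatter_list = [torch.ones(tensor_size), torch.ones(tensor_size) * 5]
    else:
        scatter_list = None

    dist.scatter(output_tensor, scatter_list, src=0)
    print(f"rank {rank}: {output_tensor}")  # rank 0 -> [1., 1.], rank 1 -> [5., 5.]

    dist.destroy_process_group()

if __name__ == "__main__":
    main()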
