@@ -3598,6 +3598,24 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
35983598 Async work handle, if async_op is set to True.
35993599 None, if not async_op or if not part of the group
36003600
3601+    .. note:: All tensors in ``gather_list`` must have the same size.
3602+
3603+ Example::
3604+ >>> # xdoctest: +SKIP("no rank")
3605+        >>> # We have one process group with 2 ranks.
3606+ >>> tensor_size = 2
3607+ >>> device = torch.device(f'cuda:{rank}')
3608+ >>> tensor = torch.ones(tensor_size, device=device) + rank
3609+ >>> if dist.get_rank() == 0:
3610+ >>> gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)]
3611+ >>> else:
3612+ >>> gather_list = None
3613+ >>> dist.gather(tensor, gather_list, dst=0)
3614+ >>> # Rank 0 gets gathered data.
3615+ >>> gather_list
3616+ [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0
3617+ None # Rank 1
3618+
36013619 """
36023620 _check_single_tensor (tensor , "tensor" )
36033621
@@ -3665,19 +3683,21 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
36653683 >>> # Note: Process group initialization omitted on each rank.
36663684 >>> import torch.distributed as dist
36673685 >>> tensor_size = 2
3668- >>> t_ones = torch.ones(tensor_size)
3669- >>> t_fives = torch.ones(tensor_size) * 5
3670- >>> output_tensor = torch.zeros(tensor_size)
3686+ >>> device = torch.device(f'cuda:{rank}')
3687+ >>> output_tensor = torch.zeros(tensor_size, device=device)
36713688 >>> if dist.get_rank() == 0:
36723689 >>> # Assumes world_size of 2.
36733690 >>> # Only tensors, all of which must be the same size.
3691+ >>> t_ones = torch.ones(tensor_size, device=device)
3692+ >>> t_fives = torch.ones(tensor_size, device=device) * 5
36743693 >>> scatter_list = [t_ones, t_fives]
36753694 >>> else:
36763695 >>> scatter_list = None
36773696 >>> dist.scatter(output_tensor, scatter_list, src=0)
3678- >>> # Rank i gets scatter_list[i]. For example, on rank 1:
3697+ >>> # Rank i gets scatter_list[i].
36793698 >>> output_tensor
3680- tensor([5., 5.])
3699+ tensor([1., 1.], device='cuda:0') # Rank 0
3700+ tensor([5., 5.], device='cuda:1') # Rank 1
36813701
36823702 """
36833703 _check_single_tensor (tensor , "tensor" )