
Commit 71eec6a

ngimel authored and pytorchmergebot committed
[dist] handle discontiguous allgather/reducescatter inputs (pytorch#163712)
Fixes pytorch#163483

Pull Request resolved: pytorch#163712
Approved by: https://github.com/ezyang, https://github.com/kwen2501
1 parent 0456b23 commit 71eec6a
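
For context: a "discontiguous" input here is a dense tensor whose strides do not match the default contiguous layout, most commonly a channels_last tensor like the one used in the new test. A minimal illustration (plain PyTorch, no process group needed):

    import torch

    # A channels_last tensor is memory-dense but reports is_contiguous() == False,
    # because its strides follow NHWC order rather than the default NCHW order.
    t = torch.arange(0, 16).view(2, 2, 2, 2).to(memory_format=torch.channels_last)
    print(t.is_contiguous())                                   # False
    print(t.is_contiguous(memory_format=torch.channels_last))  # True
    print(t.stride())                                          # (8, 1, 4, 2)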

3 files changed: +47 −8 lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 21 additions & 0 deletions
@@ -3821,6 +3821,27 @@ def test_allgather_base(self):
         dist.all_gather_into_tensor(output_tensor, tensor)
         self.assertEqual(output_tensor, tensor)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_allgather_noncontig(self):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        device = "cuda"
+        tensor = (
+            torch.arange(0, 16, device=torch.device(device))
+            .view(2, 2, 2, 2)
+            .to(memory_format=torch.channels_last)
+        )
+        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
+        dist.all_gather(tensor_list, tensor)
+        for o in tensor_list:
+            self.assertEqual(o, tensor)
+
     @requires_nccl()
     @skip_if_lt_x_gpu(1)
     @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])

torch/csrc/distributed/c10d/ProcessGroupGloo.cpp

Lines changed: 9 additions & 6 deletions
@@ -1381,7 +1381,8 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
     // Use single flat output tensor.
     // The first dimension corresponds to the index into outputs[N],
     // so copying into the actual output later is easy.
-    at::Tensor flatOutputTensor = newLikeFlat(outputs[0]);
+    at::Tensor flatOutputTensor =
+        newLikeFlat(outputs[0], /*preserve_strides*/ false);
     GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
     gloo::allgather(opts);

@@ -1398,7 +1399,7 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
   }
 
   const std::vector<at::Tensor> getOutputTensors() override {
-    return {newLikeFlat(outputs[0])};
+    return {newLikeFlat(outputs[0], /*preserve_strides*/ false)};
   }
 
   void run() override {

@@ -1694,7 +1695,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
   }
 
   const std::vector<at::Tensor> getOutputTensors() override {
-    return {newLikeFlat(output_lists[0])};
+    return {newLikeFlat(output_lists[0], /*preserve_strides*/ false)};
   }
 
   void run() override {

@@ -1818,7 +1819,7 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
     // This is later scattered to the separate output tensors.
     at::Tensor flatOutputTensor;
     if (context_->rank == root) {
-      flatOutputTensor = newLikeFlat(outputs[0]);
+      flatOutputTensor = newLikeFlat(outputs[0], /*preserve_strides*/ false);
       GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
     }

@@ -1841,7 +1842,8 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
 
   const std::vector<at::Tensor> getOutputTensors() override {
     return outputs.empty() ? std::vector<at::Tensor>{}
-                           : std::vector<at::Tensor>{newLikeFlat(outputs[0])};
+                           : std::vector<at::Tensor>{newLikeFlat(
+                                 outputs[0], /*preserve_strides*/ false)};
   }
 
   void run() override {

@@ -2057,7 +2059,8 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork {
 
   const std::vector<at::Tensor> getInputTensors() override {
     return inputs.empty() ? std::vector<at::Tensor>{}
-                          : std::vector<at::Tensor>{newLikeFlat(inputs[0])};
+                          : std::vector<at::Tensor>{newLikeFlat(
+                                inputs[0], /*preserve_strides*/ false)};
   }
 
  const std::vector<at::Tensor> getOutputTensors() override {
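
These Gloo work classes stage the collective into a single flat tensor and later copy each slice back into the user-provided outputs, so the staging buffer itself does not need to mirror the outputs' strides; passing /*preserve_strides*/ false keeps it plainly contiguous. A rough Python sketch of that staging pattern (gather_via_flat_buffer and its arguments are illustrative, not the Gloo API):

    import torch

    def gather_via_flat_buffer(outputs, contributions):
        # Stage into one contiguous buffer with a leading "rank" dimension,
        # then copy each slice back into the (possibly channels_last) outputs;
        # copy_ handles the stride mismatch between source and destination.
        flat = torch.empty(
            (len(contributions),) + tuple(contributions[0].shape),
            dtype=contributions[0].dtype,
        )
        for i, c in enumerate(contributions):
            flat[i].copy_(c)  # in the real path this is filled by communication
        for o, slice_ in zip(outputs, flat):
            o.copy_(slice_)
        return outputs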

torch/csrc/distributed/c10d/Utils.hpp

Lines changed: 17 additions & 2 deletions
@@ -444,15 +444,30 @@ inline at::Tensor newLikeFlat(
       sizes, strides, t.options().memory_format(std::nullopt));
 }
 
-inline at::Tensor newLikeFlat(std::vector<at::Tensor>& tensors) {
+inline at::Tensor newLikeFlat(
+    std::vector<at::Tensor>& tensors,
+    bool preserve_strides = true) {
   if (tensors.empty()) {
     TORCH_CHECK(false, "Received an empty list");
   }
   auto& t = tensors[0];
   at::DeviceGuard gpuGuard(t.device());
   std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
   sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
-  return at::empty(sizes, t.options());
+  if (t.is_contiguous() ||
+      !preserve_strides) { // we are checking for memory format, so tensor might
+                           // not be contiguous
+    // TODO handle all non-overlapping-and-dense, although if the strides
+    // disagree in ranks we are opening a door for more bugs than currently
+    // where channels-last might disagree between ranks
+    // fast path, don't call empty_strided
+    return at::empty(sizes, t.options());
+  } else {
+    // memory-dense, but not necessarily contiguous tensor
+    std::vector<int64_t> strides{t.numel()};
+    strides.insert(strides.end(), t.strides().begin(), t.strides().end());
+    return at::empty_strided(sizes, strides, t.options());
+  }
 }
 
 inline std::vector<std::vector<int64_t>> getSizes(
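
A rough Python analogue of the updated helper, for intuition only (new_like_flat is an illustrative stand-in, not the C++ function): when the template tensor is dense but not contiguous and strides should be preserved, the stacked allocation reuses the template's strides, with the new leading dimension strided by numel().

    import torch

    def new_like_flat(tensors, preserve_strides=True):
        t = tensors[0]
        sizes = [len(tensors)] + list(t.shape)
        if t.is_contiguous() or not preserve_strides:
            # fast path: plain contiguous allocation
            return torch.empty(sizes, dtype=t.dtype, device=t.device)
        # memory-dense but not contiguous (e.g. channels_last): keep t's strides,
        # and stride the new leading dimension by t.numel()
        strides = [t.numel()] + list(t.stride())
        return torch.empty_strided(sizes, strides, dtype=t.dtype, device=t.device)

    t = torch.randn(2, 2, 2, 2).to(memory_format=torch.channels_last)
    flat = new_like_flat([t, t])
    print(flat[0].stride() == t.stride())  # True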
