Skip to content

Commit 7b4ff01

Browse files
authored
refine barrier stream and add async_op to log (#1824)
1 parent 798a079 commit 7b4ff01

File tree

2 files changed

+108
-66
lines changed

2 files changed

+108
-66
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
755755
std::vector<int64_t>(), // outSplitSizes
756756
-1, // globalRankStart
757757
-1, // globalRankStride
758-
this->getSize()); // worldSize
758+
this->getSize(), // worldSize
759+
"N/A"); // async_op
759760

760761
auto ret = pointToPoint(
761762
tensor,
@@ -804,7 +805,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::recv(
804805
std::vector<int64_t>(), // outSplitSizes
805806
-1, // globalRankStart
806807
-1, // globalRankStride
807-
this->getSize()); // worldSize
808+
this->getSize(), // worldSize
809+
"N/A"); // async_op
808810

809811
auto ret = pointToPoint(
810812
tensor,
@@ -889,7 +891,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
889891
std::vector<int64_t>(), // outSplitSize
890892
-1, // globalRankStart
891893
-1, // globalRankStride
892-
this->getSize()); // worldSize
894+
this->getSize(), // worldSize
895+
opts.asyncOp); // async_op
893896

894897
auto inputs = std::vector<at::Tensor>{inputTensor};
895898
return collective(
@@ -1003,7 +1006,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
10031006
std::vector<int64_t>(), // outSplitSize
10041007
-1, // globalRankStart
10051008
-1, // globalRankStride
1006-
this->getSize()); // worldSize
1009+
this->getSize(), // worldSize
1010+
opts.asyncOp); // async_op
10071011

10081012
const auto root = opts.rootRank;
10091013

@@ -1131,7 +1135,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
11311135
std::vector<int64_t>(), // outSplitSizes
11321136
-1, // globalRankStart
11331137
-1, // globalRankStride
1134-
size_); // worldSize
1138+
size_, // worldSize
1139+
opts.asyncOp); // async_op
11351140

11361141
return allreduce_impl(tensor, "xccl:all_reduce", opts);
11371142
}
@@ -1157,7 +1162,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
11571162
std::vector<int64_t>(), // outSplitSizes
11581163
-1, // globalRankStart
11591164
-1, // globalRankStride
1160-
this->getSize()); // worldSize
1165+
this->getSize(), // worldSize
1166+
opts.asyncOp); // async_op
11611167

11621168
return collectiveCoalesced(
11631169
tensors,
@@ -1219,7 +1225,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
12191225
std::vector<int64_t>(), // outSplitSizes
12201226
-1, // globalRankStart
12211227
-1, // globalRankStride
1222-
this->getSize()); // worldSize
1228+
this->getSize(), // worldSize
1229+
opts.asyncOp); // async_op
12231230

12241231
const auto root = opts.rootRank + opts.rootTensor;
12251232

@@ -1310,7 +1317,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
13101317
std::vector<int64_t>(), // outSplitSizes
13111318
-1, // globalRankStart
13121319
-1, // globalRankStride
1313-
this->getSize()); // worldSize
1320+
this->getSize(), // worldSize
1321+
opts.asyncOp); // async_op
13141322

13151323
return collective(
13161324
tensor,
@@ -1419,7 +1427,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
14191427
std::vector<int64_t>(), // outSplitSize
14201428
-1, // globalRankStart
14211429
-1, // globalRankStride
1422-
this->getSize()); // worldSize
1430+
this->getSize(), // worldSize
1431+
opts.asyncOp); // async_op
14231432

14241433
bool same_size = checkSameSize(outputTensors_);
14251434
if (same_size) {
@@ -1506,7 +1515,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
15061515
std::vector<int64_t>(), // outSplitSize
15071516
-1, // globalRankStart
15081517
-1, // globalRankStride
1509-
this->getSize()); // worldSize
1518+
this->getSize(), // worldSize
1519+
opts.asyncOp); // async_op
15101520

15111521
return collective(
15121522
input_tensor,
@@ -1552,7 +1562,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
15521562
std::vector<int64_t>(), // outSplitSizes
15531563
-1, // globalRankStart
15541564
-1, // globalRankStride
1555-
this->getSize()); // worldSize
1565+
this->getSize(), // worldSize
1566+
opts.asyncOp); // async_op
15561567

15571568
return collectiveCoalesced(
15581569
inputs,
@@ -1603,7 +1614,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
16031614
std::vector<int64_t>(), // outSplitSizes
16041615
-1, // globalRankStart
16051616
-1, // globalRankStride
1606-
this->getSize()); // worldSize
1617+
this->getSize(), // worldSize
1618+
opts.asyncOp); // async_op
16071619

16081620
bool same_size = checkSameSize(inputTensors_);
16091621
if (same_size) {
@@ -1700,7 +1712,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
17001712
std::vector<int64_t>(), // outSplitSizes
17011713
-1, // globalRankStart
17021714
-1, // globalRankStride
1703-
this->getSize()); // worldSize
1715+
this->getSize(), // worldSize
1716+
opts.asyncOp); // async_op
17041717

17051718
return collective(
17061719
inputTensor,
@@ -1740,7 +1753,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
17401753
std::vector<at::Tensor>& outputs,
17411754
std::vector<at::Tensor>& inputs,
17421755
const ReduceScatterOptions& opts) {
1743-
17441756
RECORD_PARAM_COMMS_DATA_WITH_LOG(
17451757
std::make_tuple(
17461758
static_cast<int64_t>(seqCollective_) + 1,
@@ -1758,7 +1770,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
17581770
std::vector<int64_t>(), // outSplitSizes
17591771
-1, // globalRankStart
17601772
-1, // globalRankStride
1761-
this->getSize()); // worldSize
1773+
this->getSize(), // worldSize
1774+
opts.asyncOp); // async_op
17621775

17631776
return collectiveCoalesced(
17641777
inputs,
@@ -1794,6 +1807,25 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
17941807
"xccl:reduce_scatter_tensor_coalesced");
17951808
}
17961809

1810+
c10::DeviceIndex ProcessGroupXCCL::guessDeviceId() const {
1811+
if (getBoundDeviceId().has_value()) {
1812+
return getBoundDeviceId().value().index();
1813+
} else if (!usedDeviceIdxs_.empty()) {
1814+
return *usedDeviceIdxs_.begin();
1815+
}
1816+
int devIdx =
1817+
static_cast<int16_t>(rank_ % at::detail::getXPUHooks().getNumGPUs());
1818+
LOG(WARNING)
1819+
<< logPrefix()
1820+
<< c10::str(
1821+
" using GPU ",
1822+
devIdx,
1823+
" as device used by this process is currently unknown. ",
1824+
"This can potentially cause a hang if this rank to GPU mapping is incorrect. ",
1825+
"You can specify device_id in init_process_group() to force use of a particular device.");
1826+
return static_cast<c10::DeviceIndex>(devIdx);
1827+
}
1828+
17971829
c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
17981830
RECORD_PARAM_COMMS(
17991831
static_cast<int>(
@@ -1810,18 +1842,13 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
18101842
-1, // globalRankStride
18111843
this->getSize()); // worldSize
18121844
// Device to use for barrier
1813-
int barDevIdx = -1;
1845+
c10::DeviceIndex barDevIdx = -1;
18141846

18151847
// See nccl barrier comments
18161848
if (!opts.device_ids.empty()) {
1817-
barDevIdx = opts.device_ids[0];
1818-
} else if (getBoundDeviceId()) {
1819-
barDevIdx = (*getBoundDeviceId()).index();
1820-
} else if (!usedDeviceIdxs_.empty()) {
1821-
barDevIdx = *usedDeviceIdxs_.begin();
1849+
barDevIdx = static_cast<c10::DeviceIndex>(opts.device_ids[0]);
18221850
} else {
1823-
barDevIdx =
1824-
static_cast<int16_t>(rank_ % at::detail::getXPUHooks().getNumGPUs());
1851+
barDevIdx = guessDeviceId();
18251852
}
18261853

18271854
TORCH_CHECK_WITH(
@@ -1833,12 +1860,20 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
18331860
at::Tensor barrierTensor =
18341861
at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));
18351862

1836-
auto work = allreduce_impl(barrierTensor, "xccl:all_reduce_barrier");
1863+
AllreduceOptions arOpts = AllreduceOptions();
1864+
arOpts.asyncOp = opts.asyncOp;
1865+
auto work = allreduce_impl(barrierTensor, "xccl:all_reduce_barrier", arOpts);
1866+
1867+
if (opts.asyncOp) {
1868+
auto xcclWork = dynamic_cast<ProcessGroupXCCL::WorkXCCL*>(work.get());
1869+
TORCH_CHECK(xcclWork);
1870+
xcclWork->isBarrierOp_ = true;
1871+
return work;
1872+
}
18371873

1838-
auto xcclWork = dynamic_cast<ProcessGroupXCCL::WorkXCCL*>(work.get());
1839-
TORCH_CHECK(xcclWork);
1840-
xcclWork->isBarrierOp_ = true;
1841-
return work;
1874+
auto currentStream = at::xpu::getCurrentXPUStream(barDevIdx);
1875+
currentStream.synchronize();
1876+
return nullptr;
18421877
}
18431878

18441879
c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
@@ -1866,7 +1901,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
18661901
std::vector<int64_t>(), // outSplitSizes
18671902
-1, // globalRankStart
18681903
-1, // globalRankStride
1869-
this->getSize()); // worldSize
1904+
this->getSize(), // worldSize
1905+
opts.asyncOp); // async_op
1906+
18701907
TORCH_CHECK(
18711908
outputTensor.numel() == inputTensor.numel() &&
18721909
outputTensor.scalar_type() == inputTensor.scalar_type(),
@@ -1915,7 +1952,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
19151952
outputSplitSizes, // outSplitSizes
19161953
-1, // globalRankStart
19171954
-1, // globalRankStride
1918-
this->getSize()); // worldSize
1955+
this->getSize(), // worldSize
1956+
opts.asyncOp); // async_op
19191957

19201958
return collective(
19211959
inputTensor,
@@ -1991,7 +2029,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
19912029
std::vector<int64_t>(), // outSplitSizes
19922030
-1, // globalRankStart
19932031
-1, // globalRankStride
1994-
this->getSize()); // worldSize
2032+
this->getSize(), // worldSize
2033+
opts.asyncOp); // async_op
19952034

19962035
return collective(
19972036
inputTensors,

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
367367

368368
const std::string& logPrefix() const;
369369

370+
c10::DeviceIndex guessDeviceId() const;
371+
370372
protected:
371373
std::unordered_map<std::string, std::pair<at::xpu::XPUStream, ccl::stream>>
372374
xcclStreamsMap_;
@@ -465,41 +467,42 @@ inline std::string reduceOpToString(c10d::ReduceOp op) {
465467
// Since the current profiler trace support for XCCL is unclear, wrap
466468
// `RECORD_PARAM_COMMS_DATA` and output parameters as debug logs.
467469
// export TORCH_CPP_LOG_LEVEL=INFO
468-
#define RECORD_PARAM_COMMS_DATA_WITH_LOG( \
469-
seq, \
470-
pg_name_tuple, \
471-
inputTensors, \
472-
outputTensors, \
473-
rank, \
474-
collective_name, \
475-
inNelems, \
476-
outNelems, \
477-
dType, \
478-
inSplitSizes, \
479-
outSplitSizes, \
480-
globalRankStart, \
481-
globalRankStride, \
482-
worldSize) \
483-
do { \
484-
LOG(INFO) << "collective_name: " << collective_name \
485-
<< ", inNelems: " << inNelems << ", outNelems: " << outNelems \
486-
<< ", dType: " << dType << ", root/src rank: " << rank \
487-
<< ", worldSize: " << worldSize; \
488-
RECORD_PARAM_COMMS_DATA( \
489-
seq, \
490-
pg_name_tuple, \
491-
inputTensors, \
492-
outputTensors, \
493-
rank, \
494-
collective_name, \
495-
inNelems, \
496-
outNelems, \
497-
dType, \
498-
inSplitSizes, \
499-
outSplitSizes, \
500-
globalRankStart, \
501-
globalRankStride, \
502-
worldSize); \
470+
#define RECORD_PARAM_COMMS_DATA_WITH_LOG( \
471+
seq, \
472+
pg_name_tuple, \
473+
inputTensors, \
474+
outputTensors, \
475+
rank, \
476+
collective_name, \
477+
inNelems, \
478+
outNelems, \
479+
dType, \
480+
inSplitSizes, \
481+
outSplitSizes, \
482+
globalRankStart, \
483+
globalRankStride, \
484+
worldSize, \
485+
async_op) \
486+
do { \
487+
LOG(INFO) << std::boolalpha << "collective_name: " << collective_name \
488+
<< ", inNelems: " << inNelems << ", outNelems: " << outNelems \
489+
<< ", dType: " << dType << ", root/src rank: " << rank \
490+
<< ", worldSize: " << worldSize << ", async_op: " << async_op; \
491+
RECORD_PARAM_COMMS_DATA( \
492+
seq, \
493+
pg_name_tuple, \
494+
inputTensors, \
495+
outputTensors, \
496+
rank, \
497+
collective_name, \
498+
inNelems, \
499+
outNelems, \
500+
dType, \
501+
inSplitSizes, \
502+
outSplitSizes, \
503+
globalRankStart, \
504+
globalRankStride, \
505+
worldSize); \
503506
} while (0)
504507
} // namespace
505508
#endif // USE_C10D_XCCL

0 commit comments

Comments
 (0)