@@ -755,7 +755,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      "N/A"); // async_op

   auto ret = pointToPoint(
       tensor,
@@ -804,7 +805,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::recv(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      "N/A"); // async_op

   auto ret = pointToPoint(
       tensor,
@@ -889,7 +891,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
       std::vector<int64_t>(), // outSplitSize
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   auto inputs = std::vector<at::Tensor>{inputTensor};
   return collective(
@@ -1003,7 +1006,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
       std::vector<int64_t>(), // outSplitSize
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   const auto root = opts.rootRank;

@@ -1131,7 +1135,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      size_); // worldSize
+      size_, // worldSize
+      opts.asyncOp); // async_op

   return allreduce_impl(tensor, "xccl:all_reduce", opts);
 }
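
With this change every collective forwards the caller's opts.asyncOp flag into the tracing macro as the new trailing async_op argument, while the point-to-point send/recv paths record the placeholder "N/A". A minimal caller-side sketch of where that flag originates is below; the backend reference and tensor vector are illustrative names, not part of the patch.

#include <torch/csrc/distributed/c10d/Backend.hpp>

// Hedged sketch: issue an asynchronous all-reduce against the backend, so that
// opts.asyncOp == true is the value the new async_op argument of
// RECORD_PARAM_COMMS_DATA_WITH_LOG ends up recording for this collective.
void allreduceAsync(c10d::Backend& backend, std::vector<at::Tensor>& tensors) {
  c10d::AllreduceOptions opts;
  opts.asyncOp = true; // forwarded unchanged into the trace entry
  auto work = backend.allreduce(tensors, opts);
  work->wait(); // async op: the caller is responsible for waiting
}
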
@@ -1157,7 +1162,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   return collectiveCoalesced(
       tensors,
@@ -1219,7 +1225,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   const auto root = opts.rootRank + opts.rootTensor;

@@ -1310,7 +1317,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   return collective(
       tensor,
@@ -1419,7 +1427,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
       std::vector<int64_t>(), // outSplitSize
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   bool same_size = checkSameSize(outputTensors_);
   if (same_size) {
@@ -1506,7 +1515,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
       std::vector<int64_t>(), // outSplitSize
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   return collective(
       input_tensor,
@@ -1552,7 +1562,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   return collectiveCoalesced(
       inputs,
@@ -1603,7 +1614,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   bool same_size = checkSameSize(inputTensors_);
   if (same_size) {
@@ -1700,7 +1712,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   return collective(
       inputTensor,
@@ -1740,7 +1753,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
     std::vector<at::Tensor>& outputs,
     std::vector<at::Tensor>& inputs,
     const ReduceScatterOptions& opts) {
-
   RECORD_PARAM_COMMS_DATA_WITH_LOG(
       std::make_tuple(
           static_cast<int64_t>(seqCollective_) + 1,
@@ -1758,7 +1770,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   return collectiveCoalesced(
       inputs,
@@ -1794,6 +1807,25 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
       "xccl:reduce_scatter_tensor_coalesced");
 }

+c10::DeviceIndex ProcessGroupXCCL::guessDeviceId() const {
+  if (getBoundDeviceId().has_value()) {
+    return getBoundDeviceId().value().index();
+  } else if (!usedDeviceIdxs_.empty()) {
+    return *usedDeviceIdxs_.begin();
+  }
+  int devIdx =
+      static_cast<int16_t>(rank_ % at::detail::getXPUHooks().getNumGPUs());
+  LOG(WARNING)
+      << logPrefix()
+      << c10::str(
+             " using GPU ",
+             devIdx,
+             " as device used by this process is currently unknown. ",
+             "This can potentially cause a hang if this rank to GPU mapping is incorrect. ",
+             "You can specify device_id in init_process_group() to force use of a particular device.");
+  return static_cast<c10::DeviceIndex>(devIdx);
+}
+
 c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
   RECORD_PARAM_COMMS(
       static_cast<int>(
@@ -1810,18 +1842,13 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
       -1, // globalRankStride
       this->getSize()); // worldSize
   // Device to use for barrier
-  int barDevIdx = -1;
+  c10::DeviceIndex barDevIdx = -1;

   // See nccl barrier comments
   if (!opts.device_ids.empty()) {
-    barDevIdx = opts.device_ids[0];
-  } else if (getBoundDeviceId()) {
-    barDevIdx = (*getBoundDeviceId()).index();
-  } else if (!usedDeviceIdxs_.empty()) {
-    barDevIdx = *usedDeviceIdxs_.begin();
+    barDevIdx = static_cast<c10::DeviceIndex>(opts.device_ids[0]);
   } else {
-    barDevIdx =
-        static_cast<int16_t>(rank_ % at::detail::getXPUHooks().getNumGPUs());
+    barDevIdx = guessDeviceId();
   }

   TORCH_CHECK_WITH(
@@ -1833,12 +1860,20 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
   at::Tensor barrierTensor =
       at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));

-  auto work = allreduce_impl(barrierTensor, "xccl:all_reduce_barrier");
+  AllreduceOptions arOpts = AllreduceOptions();
+  arOpts.asyncOp = opts.asyncOp;
+  auto work = allreduce_impl(barrierTensor, "xccl:all_reduce_barrier", arOpts);
+
+  if (opts.asyncOp) {
+    auto xcclWork = dynamic_cast<ProcessGroupXCCL::WorkXCCL*>(work.get());
+    TORCH_CHECK(xcclWork);
+    xcclWork->isBarrierOp_ = true;
+    return work;
+  }

-  auto xcclWork = dynamic_cast<ProcessGroupXCCL::WorkXCCL*>(work.get());
-  TORCH_CHECK(xcclWork);
-  xcclWork->isBarrierOp_ = true;
-  return work;
+  auto currentStream = at::xpu::getCurrentXPUStream(barDevIdx);
+  currentStream.synchronize();
+  return nullptr;
 }

 c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
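
The barrier rewrite above changes what callers get back: with asyncOp set they still receive a WorkXCCL flagged via isBarrierOp_, but on the synchronous path the current XPU stream is synchronized inside barrier() and nullptr is returned. A hedged caller-side sketch of handling both cases follows; the function and variable names are illustrative, not part of the patch.

#include <torch/csrc/distributed/c10d/Backend.hpp>

// Hedged sketch: after this change the XCCL backend only hands back a Work
// object on the async path; the synchronous path has already synchronized the
// current XPU stream internally and returns nullptr.
c10::intrusive_ptr<c10d::Work> runBarrier(c10d::Backend& backend, bool async) {
  c10d::BarrierOptions opts;
  opts.asyncOp = async;
  auto work = backend.barrier(opts);
  if (work) {
    work->wait(); // async barrier: block only when the caller needs completion
  }
  return work; // nullptr when asyncOp == false
}
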
@@ -1866,7 +1901,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op
+
   TORCH_CHECK(
       outputTensor.numel() == inputTensor.numel() &&
           outputTensor.scalar_type() == inputTensor.scalar_type(),
@@ -1915,7 +1952,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
       outputSplitSizes, // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   return collective(
       inputTensor,
@@ -1991,7 +2029,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
       std::vector<int64_t>(), // outSplitSizes
       -1, // globalRankStart
       -1, // globalRankStride
-      this->getSize()); // worldSize
+      this->getSize(), // worldSize
+      opts.asyncOp); // async_op

   return collective(
       inputTensors,