
Commit 1bfe95d

Liu Ke authored and meta-codesync[bot] committed
integrate DDA AllToAll to rcclx
Summary:
- Integrate DDA AllToAll into RCCLX
- Compare rccl-tests performance of DDA vs. RCCLX on MI300x and MI350x
- Tune DdaAllToAllMaxBytes based on MI300x perf

Reviewed By: cenzhaometa

Differential Revision: D86030871

fbshipit-source-id: 3ad5903e22b028eeb2dddf56ae8104e243d2a066
1 parent 6ed088b commit 1bfe95d

File tree

9 files changed: +379 −5 lines changed

comms/common/algorithms/AlgoFactory.cu

Lines changed: 15 additions & 3 deletions

@@ -14,11 +14,12 @@ AlgoFactory::AlgoFactory(
     int maxBlocks,
     const AllReduceOptions& allReduceOpts,
     const AllGatherOptions& allGatherOpts,
-    const ReduceScatterOptions& reduceScatterOpts) {
+    const ReduceScatterOptions& reduceScatterOpts,
+    const AllToAllOptions& allToAllOpts) {
   if (allReduceOpts.enableDda || allGatherOpts.enableDda ||
-      reduceScatterOpts.enableDda) {
+      reduceScatterOpts.enableDda || allToAllOpts.enableDda) {
     XLOG(DBG)
-        << "Initializing AllReduceAlgoManager / AllGatherAlgoManager / ReduceScatterAlgoManager";
+        << "Initializing AllReduceAlgoManager / AllGatherAlgoManager / ReduceScatterAlgoManager / AllToAllAlgoManager";
 
     for (int i = 0; i < nRanks; ++i) {
       if (i == selfRank) {
@@ -64,6 +65,17 @@ AlgoFactory::AlgoFactory(
         reduceScatterOpts.ddaMaxThresholdBytes);
     XLOG(DBG) << "Successfully initialized ReduceScatterAlgoManager";
   }
+
+  if (allToAllOpts.enableDda) {
+    allToAllMgr_ = std::make_unique<AllToAllAlgoManager>(
+        bootstrap,
+        nRanks,
+        selfRank,
+        maxBlocks,
+        allToAllOpts.ddaSendbufSizeBytes,
+        allToAllOpts.ddaMaxThresholdBytes);
+    XLOG(DBG) << "Successfully initialized AllToAllAlgoManager";
+  }
 }
 
 } // namespace meta::comms
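The new AllToAllOptions plumb through the factory the same way as the existing collectives. A minimal wiring sketch, assuming a caller that already holds the bootstrap and the other option structs; the byte values are illustrative only (the summary says DdaAllToAllMaxBytes was tuned on MI300x, but the tuned value is not part of this hunk):

  // Illustrative values; real ones come from the tuned configuration.
  AlgoFactory::AllToAllOptions allToAllOpts;
  allToAllOpts.enableDda = true;
  allToAllOpts.ddaSendbufSizeBytes = 1 << 20;   // example staging capacity
  allToAllOpts.ddaMaxThresholdBytes = 32 << 10; // example DDA cutoff

  auto factory = std::make_unique<AlgoFactory>(
      bootstrap, nRanks, selfRank, maxBlocks,
      allReduceOpts, allGatherOpts, reduceScatterOpts, allToAllOpts);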

comms/common/algorithms/AlgoFactory.cuh

Lines changed: 26 additions & 1 deletion

@@ -4,6 +4,7 @@
 
 #include "comms/common/algorithms/all_gather/AllGatherAlgoManager.h"
 #include "comms/common/algorithms/all_reduce/AllReduceAlgoManager.h"
+#include "comms/common/algorithms/all_to_all/AllToAllAlgoManager.h"
 #include "comms/common/algorithms/reduce_scatter/ReduceScatterAlgoManager.h"
 #include "comms/ctran/interfaces/IBootstrap.h" // @manual
 #include "comms/utils/commSpecs.h"
@@ -14,6 +15,7 @@ namespace meta::comms {
 class AlgoManagerAllReduce;
 class AlgoManagerAllGather;
 class AlgoManagerReduceScatter;
+class AlgoManagerAllToAll;
 
 /**
  * per communicator per rank Algorithm factory that
@@ -46,14 +48,22 @@ class AlgoFactory {
     // DDA will be used
     int ddaMaxThresholdBytes{0};
   };
+  struct AllToAllOptions {
+    bool enableDda{false};
+    int ddaSendbufSizeBytes{0};
+    // If msg size is not larger than the threshold,
+    // DDA will be used
+    int ddaMaxThresholdBytes{0};
+  };
   AlgoFactory(
       std::shared_ptr<ctran::bootstrap::IBootstrap> bootstrap,
       int nRanks,
       int selfRank,
       int maxBlocks,
       const AllReduceOptions& allReduceOpts,
       const AllGatherOptions& allGatherOpts,
-      const ReduceScatterOptions& reduceScatterOpts);
+      const ReduceScatterOptions& reduceScatterOpts,
+      const AllToAllOptions& allToAllOpts);
 
   std::unique_ptr<AlgoAllReduce> getAllReduceAlgo(
       const void* sendbuff,
@@ -97,9 +107,24 @@
         sendbuff, recvbuff, count, datatype, stream, acc);
   }
 
+  std::unique_ptr<AlgoAllToAll> getAllToAllAlgo(
+      const void* sendbuff,
+      void* recvbuff,
+      size_t count,
+      commDataType_t datatype,
+      cudaStream_t stream,
+      const void* acc = nullptr) {
+    if (allToAllMgr_ == nullptr) {
+      return nullptr;
+    }
+    return allToAllMgr_->getAllToAllAlgo(
+        sendbuff, recvbuff, count, datatype, stream, acc);
+  }
+
  private:
  std::unique_ptr<AllReduceAlgoManager> allReduceMgr_{nullptr};
  std::unique_ptr<AllGatherAlgoManager> allGatherMgr_{nullptr};
  std::unique_ptr<ReduceScatterAlgoManager> reduceScatterMgr_{nullptr};
+  std::unique_ptr<AllToAllAlgoManager> allToAllMgr_{nullptr};
 };
 } // namespace meta::comms
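getAllToAllAlgo() mirrors the other getters: it returns nullptr when the manager was never created (DDA disabled) or when the manager rejects the message, so callers keep a fallback path. A hedged usage sketch; fallbackAllToAll is a hypothetical stand-in for the existing RCCLX all-to-all path:

  if (auto algo = factory->getAllToAllAlgo(
          sendbuff, recvbuff, count, commBfloat16, stream)) {
    algo->allToAll(); // DDA fast path
  } else {
    // hypothetical baseline path, not part of this commit
    fallbackAllToAll(sendbuff, recvbuff, count, commBfloat16, stream);
  }
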
comms/common/algorithms/all_to_all/AlgoAllToAll.cu

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#include "comms/common/algorithms/all_to_all/AlgoAllToAll.cuh"
+#include "comms/utils/checks.h"
+
+namespace meta::comms {
+
+AlgoAllToAll::AlgoAllToAll(
+    const void* sendbuff,
+    void** allRankDdaSendbuffs,
+    void* recvbuff,
+    size_t count,
+    commDataType_t datatype,
+    cudaStream_t stream,
+    int nRanks,
+    int selfRank,
+    int maxBlocks,
+    IpcGpuBarrier* barrier,
+    const void* acc)
+    : sendbuff_(sendbuff),
+      allRankDdaSendbuffs_(allRankDdaSendbuffs),
+      recvbuff_(recvbuff),
+      count_(count),
+      datatype_(datatype),
+      stream_(stream),
+      nRanks_(nRanks),
+      selfRank_(selfRank),
+      maxBlocks_(maxBlocks),
+      barrier_(barrier),
+      acc_(acc) {}
+
+void AlgoAllToAllDdaIpc::allToAll() {
+  TYPED_CALL(datatype_, launchKernel);
+}
+
+} // namespace meta::comms
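TYPED_CALL comes from comms/utils/checks.h and its definition is not shown in this diff. Given that the manager only admits bf16 and half (see AllToAllAlgoManager.cu below), its expansion here is presumably equivalent to a datatype switch over the templated launchKernel; a sketch of that assumed dispatch, using the CUDA types from <cuda_bf16.h> and <cuda_fp16.h>:

  // Assumed expansion of TYPED_CALL(datatype_, launchKernel); not the
  // actual macro body.
  switch (datatype_) {
    case commBfloat16:
      launchKernel<__nv_bfloat16>();
      break;
    case commFloat16:
      launchKernel<half>();
      break;
    default:
      // unsupported dtypes never reach here; the manager filters them out
      break;
  }
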
comms/common/algorithms/all_to_all/AlgoAllToAll.cuh

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#pragma once
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include "comms/common/IpcGpuBarrier.cuh"
+#include "comms/common/algorithms/AlgoUtils.h"
+#include "comms/common/algorithms/all_to_all/all_to_all_dda.cuh"
+#include "comms/utils/checks.h"
+#include "comms/utils/commSpecs.h"
+
+namespace meta::comms {
+
+/**
+ * This class defines the common interface for all AllToAll algorithms;
+ * subclasses are expected to provide the actual implementation.
+ */
+class AlgoAllToAll {
+ public:
+  // NOTE: acc is not used for all-to-all
+  AlgoAllToAll(
+      const void* sendbuff,
+      void** allRankDdaSendbuffs,
+      void* recvbuff,
+      size_t count,
+      commDataType_t datatype,
+      cudaStream_t stream,
+      int nRanks,
+      int selfRank,
+      int maxBlocks,
+      IpcGpuBarrier* barrier,
+      const void* acc);
+
+  virtual ~AlgoAllToAll() = default;
+
+  virtual void allToAll() = 0;
+
+ protected:
+  const void* sendbuff_{nullptr};
+  void** allRankDdaSendbuffs_{nullptr};
+  void* recvbuff_{nullptr};
+  size_t count_{0};
+  commDataType_t datatype_{commBfloat16};
+  cudaStream_t stream_{nullptr};
+  int nRanks_{0};
+  int selfRank_{0};
+  const size_t maxBlocks_{0};
+  IpcGpuBarrier* barrier_;
+  const void* acc_{nullptr};
+};
+
+class AlgoAllToAllDdaIpc : public AlgoAllToAll {
+ public:
+  using AlgoAllToAll::AlgoAllToAll;
+
+  void allToAll() override;
+
+ private:
+  template <typename T>
+  void launchKernel() {
+    const void* func = nullptr;
+
+    ASSIGN_FUNC_NRANKS(func, ddaAllToAllIpc, nRanks_, false /* hasAcc */);
+
+    auto gridBlock =
+        getGridAndBlockDims(nRanks_ * count_, datatype_, maxBlocks_);
+    const auto& grid = gridBlock.first;
+    const auto& block = gridBlock.second;
+
+    void* args[] = {
+        &allRankDdaSendbuffs_,
+        &recvbuff_,
+        &count_,
+        &sendbuff_,
+        &selfRank_,
+        barrier_,
+        &acc_};
+    CUDA_CHECK(cudaLaunchKernel(func, grid, block, args, 0, stream_));
+  }
+};
+
+} // namespace meta::comms
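The ddaAllToAllIpc kernel itself lives in all_to_all_dda.cuh and is not part of this diff. A schematic of the exchange it plausibly implements, with the staging of sendbuff and the IpcGpuBarrier synchronization elided: every rank writes its data into its own IPC-mapped staging buffer, all ranks barrier, then rank r gathers slice r from each peer's staging buffer directly over IPC. The kernel below is a hypothetical sketch only, not the shipped implementation:

  // Hypothetical gather phase of a DDA IPC all-to-all; the real kernel
  // also stages sendbuff and barriers between the stage and gather phases.
  template <typename T>
  __global__ void ddaAllToAllIpcSketch(
      T** allRankDdaSendbuffs, // every rank's IPC-mapped staging buffer
      T* recvbuff,
      size_t count, // elements exchanged per rank pair
      int nRanks,
      int selfRank) {
    for (int peer = 0; peer < nRanks; ++peer) {
      // slice selfRank of peer's staging buffer is addressed to this rank
      const T* src = allRankDdaSendbuffs[peer] + selfRank * count;
      T* dst = recvbuff + peer * count;
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count;
           i += size_t(gridDim.x) * blockDim.x) {
        dst[i] = src[i];
      }
    }
  }
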
comms/common/algorithms/all_to_all/AllToAllAlgoManager.cu

Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#include "comms/common/algorithms/all_to_all/AllToAllAlgoManager.h"
+
+namespace meta::comms {
+
+AllToAllAlgoManager::AllToAllAlgoManager(
+    std::shared_ptr<ctran::bootstrap::IBootstrap> bootstrap,
+    int nRanks,
+    int selfRank,
+    int maxBlocks,
+    int ddaSendbufSizeBytes,
+    int ddaMaxThresholdBytes)
+    : nRanks_(nRanks),
+      selfRank_(selfRank),
+      maxBlocks_(maxBlocks),
+      ddaSendbufSizeBytes_(ddaSendbufSizeBytes),
+      ddaMaxThresholdBytes_(ddaMaxThresholdBytes) {
+  auto [barrierResources, barrier] =
+      IpcGpuBarrier::mallocAndInit(nRanks_, maxBlocks_, selfRank_, bootstrap);
+  barrierResources_ = std::move(barrierResources);
+  barrier_ = barrier;
+
+  ddaSendbuf_ = std::make_unique<DeviceBuffer>(ddaSendbufSizeBytes_ * nRanks_);
+  memHandler_ = std::make_unique<IpcMemHandler>(bootstrap, selfRank_, nRanks_);
+  memHandler_->addSelfDeviceMemPtr(ddaSendbuf_->get());
+  memHandler_->exchangeMemPtrs();
+
+  std::vector<void*> ipcSendbufs(nRanks_);
+  for (int i = 0; i < nRanks_; ++i) {
+    ipcSendbufs[i] = memHandler_->getPeerDeviceMemPtr(i);
+  }
+
+  allRankDdaSendbuffs_ =
+      std::make_unique<DeviceBuffer>(sizeof(void*) * nRanks_);
+  CUDA_CHECK(cudaMemcpy(
+      allRankDdaSendbuffs_->get(),
+      ipcSendbufs.data(),
+      sizeof(void*) * nRanks_,
+      cudaMemcpyDefault));
+  XLOG(DBG) << "Successfully initialized AllToAllAlgoManager";
+}
+
+std::unique_ptr<AlgoAllToAll> AllToAllAlgoManager::getAllToAllAlgo(
+    const void* sendbuff,
+    void* recvbuff,
+    size_t count,
+    commDataType_t datatype,
+    cudaStream_t stream,
+    const void* acc) {
+  if ((count * commTypeSize(datatype)) > ddaSendbufSizeBytes_) {
+    // msg size must fit into the dda sendbuf
+    XLOG(DBG) << "Not using custom all-to-all algo because message size "
+              << count * commTypeSize(datatype)
+              << " is larger than ddaSendbufSizeBytes " << ddaSendbufSizeBytes_;
+    return nullptr;
+  }
+
+  if (((uintptr_t)sendbuff % 16) || ((uintptr_t)recvbuff % 16) ||
+      ((count * commTypeSize(datatype)) % 16)) {
+    // 16-byte alignment as we do 16-byte loads in the DDA kernel
+    XLOG(DBG) << "Not using custom all-to-all algo because send/recv buff "
+                 "or msg size is not 16-byte aligned";
+    return nullptr;
+  }
+
+  if (datatype != commBfloat16 && datatype != commFloat16) {
+    // we currently only support bf16 and half
+    XLOG(DBG)
+        << "Not using custom all-to-all algo because cudaDataType_t datatype "
+        << static_cast<int>(datatype) << " is not supported";
+    return nullptr;
+  }
+
+  std::unique_ptr<AlgoAllToAll> algo;
+  if ((count * commTypeSize(datatype)) > ddaMaxThresholdBytes_) {
+    XLOG(DBG) << "Not using custom all-to-all algo because msg size "
+              << count * commTypeSize(datatype)
+              << " is larger than DDA algo threshold " << ddaMaxThresholdBytes_;
+    return nullptr;
+  } else {
+    if ((count * commTypeSize(datatype)) % 16) {
+      XLOG(DBG) << "Not using DDA all-to-all algo because send/recv buff "
+                   "or msg size is not 16-byte aligned for each rank";
+      return nullptr;
+    }
+    algo = std::make_unique<AlgoAllToAllDdaIpc>(
+        sendbuff,
+        reinterpret_cast<void**>(allRankDdaSendbuffs_->get()),
+        recvbuff,
+        count,
+        datatype,
+        stream,
+        nRanks_,
+        selfRank_,
+        maxBlocks_,
+        &barrier_,
+        acc);
+  }
+  return algo;
+}
+
+} // namespace meta::comms
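The gating above reduces to a small predicate over the message size. Restating it as a standalone helper makes the DDA engagement condition easy to reason about offline; this is a sketch derived from the checks above, not code from this commit (the bf16/fp16 dtype check is folded into a comment):

  #include <cstddef>
  #include <cstdint>

  // true iff the DDA all-to-all path would be taken for a bf16/half message
  bool ddaAllToAllEligible(
      const void* sendbuff,
      const void* recvbuff,
      size_t count,
      size_t dtypeSize, // 2 for bf16/half, the only supported types
      size_t ddaSendbufSizeBytes,
      size_t ddaMaxThresholdBytes) {
    const size_t msgBytes = count * dtypeSize;
    return msgBytes <= ddaSendbufSizeBytes && // must fit the staging buffer
        msgBytes <= ddaMaxThresholdBytes &&   // small-message regime only
        msgBytes % 16 == 0 &&                 // kernel does 16-byte loads
        reinterpret_cast<uintptr_t>(sendbuff) % 16 == 0 &&
        reinterpret_cast<uintptr_t>(recvbuff) % 16 == 0;
  }

For example, with bf16 (2 bytes) and count = 8192, msgBytes is 16 KiB: the DDA path engages as long as both configured limits are at least 16 KiB and the buffers are 16-byte aligned.
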
comms/common/algorithms/all_to_all/AllToAllAlgoManager.h

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#pragma once
+
+#include "comms/common/IpcGpuBarrier.cuh"
+#include "comms/common/algorithms/all_to_all/AlgoAllToAll.cuh"
+#include "comms/ctran/interfaces/IBootstrap.h" // @manual
+#include "comms/utils/CudaRAII.h"
+#include "comms/utils/commSpecs.h"
+
+namespace meta::comms {
+
+class AllToAllAlgoManager {
+ public:
+  AllToAllAlgoManager(
+      std::shared_ptr<ctran::bootstrap::IBootstrap> bootstrap,
+      int nRanks,
+      int selfRank,
+      int maxBlocks,
+      int ddaSendbufSizeBytes,
+      int ddaMaxThresholdBytes);
+  AllToAllAlgoManager(const AllToAllAlgoManager&) = delete;
+  AllToAllAlgoManager(AllToAllAlgoManager&&) = delete;
+
+  std::unique_ptr<AlgoAllToAll> getAllToAllAlgo(
+      const void* sendbuff,
+      void* recvbuff,
+      size_t count,
+      commDataType_t datatype,
+      cudaStream_t stream,
+      const void* acc);
+
+ private:
+  int nRanks_{0};
+  int selfRank_{-1};
+  int maxBlocks_{0};
+  int ddaSendbufSizeBytes_{0};
+  int ddaMaxThresholdBytes_{0};
+  std::unique_ptr<IpcGpuBarrierResources> barrierResources_;
+  IpcGpuBarrier barrier_;
+  std::unique_ptr<DeviceBuffer> ddaSendbuf_;
+  std::unique_ptr<IpcMemHandler> memHandler_;
+  // array of void* (all ranks' ipc-enabled sendbufs) in device memory
+  std::unique_ptr<DeviceBuffer> allRankDdaSendbuffs_;
+};
+
+} // namespace meta::comms
