Commit e569982

Liu Ke authored and meta-codesync[bot] committed
share internal-buffer and barrier across all collectives in AlgoFactory
Summary: In this diff:

- Refactor the duplicated **internal buffer and barrier** in AllReduceAlgoManager, AllGatherAlgoManager, ReduceScatterAlgoManager, and AllToAllAlgoManager so they are created once and shared via **AlgoFactory**.
- Remove the "acc" parameter from AllGather, ReduceScatter, and AllToAll; currently only the AllReduceWithBias API uses an accumulator.

There is no major performance regression between the previous duplicated implementation and this shared one; details are in the table linked here: https://docs.google.com/spreadsheets/d/1YcJTbc3Tjk8qmbXb4sOoCekB0hyMiwP_A0ptf1nCoVc/edit?usp=sharing

Reviewed By: cenzhaometa

Differential Revision: D86362186

fbshipit-source-id: 32d8ffa8f1e46cd1bb028ce3adae701f550c07e2
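For illustration, a minimal sketch of the resulting call-site change, assuming a hypothetical caller (the bootstrap object and every numeric size below are made-up examples, not values from this diff; struct and parameter names follow the updated AlgoFactory.cuh):

// Hypothetical call site; 'bootstrap' is an existing
// std::shared_ptr<ctran::bootstrap::IBootstrap>, and all sizes are
// illustrative assumptions.
AlgoFactory::AllReduceOptions arOpts;
arOpts.enableDda = true;
arOpts.ddaFlatMaxThresholdBytes = 64 * 1024;
arOpts.ddaTreeMaxThresholdBytes = 1024 * 1024;

AlgoFactory::AllGatherOptions agOpts;
agOpts.enableDda = true;
agOpts.ddaMaxThresholdBytes = 1024 * 1024;

// The per-collective ddaSendbufSizeBytes fields are gone; a single
// factory-wide size is passed once and backs every collective.
AlgoFactory factory(
    bootstrap,
    /*nRanks=*/8,
    /*selfRank=*/0,
    /*maxBlocks=*/16,
    /*ddaSendbufSizeBytes=*/4 * 1024 * 1024,
    arOpts,
    agOpts,
    AlgoFactory::ReduceScatterOptions{},
    AlgoFactory::AllToAllOptions{});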
1 parent 1bfe95d · commit e569982

24 files changed: +189, -285 lines

comms/common/algorithms/AlgoFactory.cu

Lines changed: 47 additions & 18 deletions
@@ -12,14 +12,19 @@ AlgoFactory::AlgoFactory(
     int nRanks,
     int selfRank,
     int maxBlocks,
+    int ddaSendbufSizeBytes,
     const AllReduceOptions& allReduceOpts,
     const AllGatherOptions& allGatherOpts,
     const ReduceScatterOptions& reduceScatterOpts,
-    const AllToAllOptions& allToAllOpts) {
+    const AllToAllOptions& allToAllOpts)
+    : nRanks_(nRanks),
+      selfRank_(selfRank),
+      maxBlocks_(maxBlocks),
+      ddaSendbufSizeBytes_(ddaSendbufSizeBytes) {
   if (allReduceOpts.enableDda || allGatherOpts.enableDda ||
       reduceScatterOpts.enableDda || allToAllOpts.enableDda) {
     XLOG(DBG)
-        << "Initializing AllReduceAlgoManager / AllGatherAlgoManager / ReduceScatterAlgoManager / AllToAllAlgoManager";
+        << "Initializing AllReduce / AllGather / ReduceScatter / AllToAll AlgoManager";
 
     for (int i = 0; i < nRanks; ++i) {
       if (i == selfRank) {
@@ -30,51 +35,75 @@ AlgoFactory::AlgoFactory(
         CUDA_CHECK(e);
       }
     }
+
+    auto [barrierResources, barrier] =
+        IpcGpuBarrier::mallocAndInit(nRanks_, maxBlocks_, selfRank_, bootstrap);
+    barrierResources_ = std::move(barrierResources);
+    barrier_ = barrier;
+
+    ddaSendbuf_ = std::make_unique<DeviceBuffer>(ddaSendbufSizeBytes_);
+    memHandler_ =
+        std::make_unique<IpcMemHandler>(bootstrap, selfRank_, nRanks_);
+    memHandler_->addSelfDeviceMemPtr(ddaSendbuf_->get());
+    memHandler_->exchangeMemPtrs();
+
+    std::vector<void*> ipcSendbufs(nRanks_);
+    for (int i = 0; i < nRanks_; ++i) {
+      ipcSendbufs[i] = memHandler_->getPeerDeviceMemPtr(i);
+    }
+
+    allRankDdaSendbuffs_ =
+        std::make_unique<DeviceBuffer>(sizeof(void*) * nRanks_);
+    CUDA_CHECK(cudaMemcpy(
+        allRankDdaSendbuffs_->get(),
+        ipcSendbufs.data(),
+        sizeof(void*) * nRanks_,
+        cudaMemcpyDefault));
   }
 
   if (allReduceOpts.enableDda) {
     allReduceMgr_ = std::make_unique<AllReduceAlgoManager>(
-        bootstrap,
         nRanks,
         selfRank,
         maxBlocks,
-        allReduceOpts.ddaSendbufSizeBytes,
+        ddaSendbufSizeBytes,
         allReduceOpts.ddaFlatMaxThresholdBytes,
-        allReduceOpts.ddaTreeMaxThresholdBytes);
-    XLOG(DBG) << "Successfully initialized AllReduceAlgoManager";
+        allReduceOpts.ddaTreeMaxThresholdBytes,
+        reinterpret_cast<void**>(allRankDdaSendbuffs_->get()),
+        &barrier_);
   }
 
   if (allGatherOpts.enableDda) {
     allGatherMgr_ = std::make_unique<AllGatherAlgoManager>(
-        bootstrap,
         nRanks,
         selfRank,
         maxBlocks,
-        allGatherOpts.ddaSendbufSizeBytes,
-        allGatherOpts.ddaMaxThresholdBytes);
-    XLOG(DBG) << "Successfully initialized AllGatherAlgoManager";
+        ddaSendbufSizeBytes,
+        allGatherOpts.ddaMaxThresholdBytes,
+        reinterpret_cast<void**>(allRankDdaSendbuffs_->get()),
+        &barrier_);
   }
 
   if (reduceScatterOpts.enableDda) {
     reduceScatterMgr_ = std::make_unique<ReduceScatterAlgoManager>(
-        bootstrap,
         nRanks,
         selfRank,
         maxBlocks,
-        reduceScatterOpts.ddaSendbufSizeBytes,
-        reduceScatterOpts.ddaMaxThresholdBytes);
-    XLOG(DBG) << "Successfully initialized ReduceScatterAlgoManager";
+        ddaSendbufSizeBytes,
+        reduceScatterOpts.ddaMaxThresholdBytes,
+        reinterpret_cast<void**>(allRankDdaSendbuffs_->get()),
+        &barrier_);
   }
 
   if (allToAllOpts.enableDda) {
     allToAllMgr_ = std::make_unique<AllToAllAlgoManager>(
-        bootstrap,
         nRanks,
         selfRank,
         maxBlocks,
-        allToAllOpts.ddaSendbufSizeBytes,
-        allToAllOpts.ddaMaxThresholdBytes);
-    XLOG(DBG) << "Successfully initialized AllToAllAlgoManager";
+        ddaSendbufSizeBytes,
+        allToAllOpts.ddaMaxThresholdBytes,
+        reinterpret_cast<void**>(allRankDdaSendbuffs_->get()),
+        &barrier_);
   }
 }
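
The pointer table built above is deliberately copied into device memory: DDA kernels dereference peers' IPC-mapped sendbufs directly by rank index. A minimal sketch of that access pattern, assuming a simplified toy kernel (peerCopyKernel is illustrative, not the real DDA kernel):

// Toy kernel showing why allRankDdaSendbuffs_ lives in device memory:
// a kernel can index any peer's IPC-mapped sendbuf by rank at runtime.
__global__ void peerCopyKernel(
    void** allRankSendbufs, int peerRank, char* dst, size_t nbytes) {
  const char* src = static_cast<const char*>(allRankSendbufs[peerRank]);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < nbytes;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    dst[i] = src[i]; // reads straight from the peer's mapped buffer
  }
}

Note also that the managers now receive raw pointers (&barrier_ and the device-resident pointer table) into factory-owned storage, so the factory must outlive them; that holds here, since AlgoFactory owns all four manager unique_ptrs.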

comms/common/algorithms/AlgoFactory.cuh

Lines changed: 25 additions & 13 deletions
@@ -2,11 +2,13 @@
 
 #pragma once
 
+#include "comms/common/IpcGpuBarrier.cuh"
 #include "comms/common/algorithms/all_gather/AllGatherAlgoManager.h"
 #include "comms/common/algorithms/all_reduce/AllReduceAlgoManager.h"
 #include "comms/common/algorithms/all_to_all/AllToAllAlgoManager.h"
 #include "comms/common/algorithms/reduce_scatter/ReduceScatterAlgoManager.h"
 #include "comms/ctran/interfaces/IBootstrap.h" // @manual
+#include "comms/utils/CudaRAII.h"
 #include "comms/utils/commSpecs.h"
 
 namespace meta::comms {
@@ -26,40 +28,41 @@ class AlgoFactory {
  public:
   struct AllReduceOptions {
     bool enableDda{false};
-    int ddaSendbufSizeBytes{0};
     // If msg size is not larger than the threshold,
     // flat (one-shot) DDA will be used
     int ddaFlatMaxThresholdBytes{0};
     // If msg size is not larger than the threshold,
     // tree (two-shot) DDA will be used
     int ddaTreeMaxThresholdBytes{0};
   };
+
   struct AllGatherOptions {
     bool enableDda{false};
-    int ddaSendbufSizeBytes{0};
     // If msg size is not larger than the threshold,
     // DDA will be used
     int ddaMaxThresholdBytes{0};
   };
+
   struct ReduceScatterOptions {
     bool enableDda{false};
-    int ddaSendbufSizeBytes{0};
     // If msg size is not larger than the threshold,
     // DDA will be used
     int ddaMaxThresholdBytes{0};
   };
+
   struct AllToAllOptions {
     bool enableDda{false};
-    int ddaSendbufSizeBytes{0};
     // If msg size is not larger than the threshold,
     // DDA will be used
     int ddaMaxThresholdBytes{0};
   };
+
   AlgoFactory(
       std::shared_ptr<ctran::bootstrap::IBootstrap> bootstrap,
       int nRanks,
       int selfRank,
       int maxBlocks,
+      int ddaSendbufSizeBytes,
       const AllReduceOptions& allReduceOpts,
       const AllGatherOptions& allGatherOpts,
       const ReduceScatterOptions& reduceScatterOpts,
@@ -84,44 +87,53 @@ class AlgoFactory {
       void* recvbuff,
       size_t count,
       commDataType_t datatype,
-      cudaStream_t stream,
-      const void* acc = nullptr) {
+      cudaStream_t stream) {
     if (allGatherMgr_ == nullptr) {
       return nullptr;
     }
     return allGatherMgr_->getAllGatherAlgo(
-        sendbuff, recvbuff, count, datatype, stream, acc);
+        sendbuff, recvbuff, count, datatype, stream);
   }
 
   std::unique_ptr<AlgoReduceScatter> getReduceScatterAlgo(
       const void* sendbuff,
       void* recvbuff,
       size_t count,
       commDataType_t datatype,
-      cudaStream_t stream,
-      const void* acc = nullptr) {
+      cudaStream_t stream) {
     if (reduceScatterMgr_ == nullptr) {
       return nullptr;
     }
     return reduceScatterMgr_->getReduceScatterAlgo(
-        sendbuff, recvbuff, count, datatype, stream, acc);
+        sendbuff, recvbuff, count, datatype, stream);
   }
 
   std::unique_ptr<AlgoAllToAll> getAllToAllAlgo(
       const void* sendbuff,
       void* recvbuff,
       size_t count,
       commDataType_t datatype,
-      cudaStream_t stream,
-      const void* acc = nullptr) {
+      cudaStream_t stream) {
     if (allToAllMgr_ == nullptr) {
       return nullptr;
     }
     return allToAllMgr_->getAllToAllAlgo(
-        sendbuff, recvbuff, count, datatype, stream, acc);
+        sendbuff, recvbuff, count, datatype, stream);
   }
 
  private:
+  int nRanks_{0};
+  int selfRank_{-1};
+  int maxBlocks_{0};
+  int ddaSendbufSizeBytes_{0};
+
+  std::unique_ptr<IpcGpuBarrierResources> barrierResources_;
+  IpcGpuBarrier barrier_;
+  std::unique_ptr<DeviceBuffer> ddaSendbuf_;
+  std::unique_ptr<IpcMemHandler> memHandler_;
+  // array of void* (all ranks' IPC-enabled sendbufs) in device memory
+  std::unique_ptr<DeviceBuffer> allRankDdaSendbuffs_;
+
   std::unique_ptr<AllReduceAlgoManager> allReduceMgr_{nullptr};
   std::unique_ptr<AllGatherAlgoManager> allGatherMgr_{nullptr};
   std::unique_ptr<ReduceScatterAlgoManager> reduceScatterMgr_{nullptr};
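
One behavioral note that survives the refactor: each getXxxAlgo accessor returns nullptr when the corresponding manager is disabled (or, inside the manager, when the message does not qualify for DDA), so callers fall back to a baseline collective. A hedged usage sketch, where factory, the buffers, and fallbackAllGather() are placeholders rather than APIs from this codebase:

// Hypothetical caller; fallbackAllGather() is a made-up stand-in for
// whatever baseline path the caller normally uses.
auto algo = factory.getAllGatherAlgo(sendbuff, recvbuff, count, datatype, stream);
if (algo != nullptr) {
  algo->allGather(); // DDA fast path
} else {
  fallbackAllGather(sendbuff, recvbuff, count, datatype, stream);
}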

comms/common/algorithms/all_gather/AlgoAllGather.cu

Lines changed: 2 additions & 4 deletions
@@ -15,8 +15,7 @@ AlgoAllGather::AlgoAllGather(
     int nRanks,
     int selfRank,
     int maxBlocks,
-    IpcGpuBarrier* barrier,
-    const void* acc)
+    IpcGpuBarrier* barrier)
     : sendbuff_(sendbuff),
       allRankDdaSendbuffs_(allRankDdaSendbuffs),
       recvbuff_(recvbuff),
@@ -26,8 +25,7 @@ AlgoAllGather::AlgoAllGather(
       nRanks_(nRanks),
       selfRank_(selfRank),
       maxBlocks_(maxBlocks),
-      barrier_(barrier),
-      acc_(acc) {}
+      barrier_(barrier) {}
 
 void AlgoAllGatherDdaIpc::allGather() {
   TYPED_CALL(datatype_, launchKernel);

comms/common/algorithms/all_gather/AlgoAllGather.cuh

Lines changed: 2 additions & 5 deletions
@@ -29,8 +29,7 @@ class AlgoAllGather {
       int nRanks,
       int selfRank,
       int maxBlocks,
-      IpcGpuBarrier* barrier,
-      const void* acc);
+      IpcGpuBarrier* barrier);
 
   virtual ~AlgoAllGather() = default;
 
@@ -47,7 +46,6 @@ class AlgoAllGather {
   int selfRank_{0};
   const size_t maxBlocks_{0};
   IpcGpuBarrier* barrier_;
-  const void* acc_{nullptr};
 };
 
 class AlgoAllGatherDdaIpc : public AlgoAllGather {
@@ -74,8 +72,7 @@ class AlgoAllGatherDdaIpc : public AlgoAllGather {
         &count_,
         &sendbuff_,
         &selfRank_,
-        barrier_,
-        &acc_};
+        barrier_};
     CUDA_CHECK(cudaLaunchKernel(func, grid, block, args, 0, stream_));
   }
 };
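
The args array above follows the cudaLaunchKernel convention: each element is the address of one kernel argument, in signature order. (Passing barrier_ without a leading & suggests the kernel takes the IpcGpuBarrier by value, since barrier_ already points at that storage.) A self-contained sketch with an assumed toy kernel:

// Toy example of the cudaLaunchKernel argument-array convention;
// kernel name and signature are illustrative only.
#include <cuda_runtime.h>

__global__ void toyKernel(int n, float* out) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    out[0] = static_cast<float>(n);
  }
}

cudaError_t launchToy(int n, float* out, cudaStream_t stream) {
  void* args[] = {&n, &out}; // addresses of the arguments, in order
  dim3 grid(1), block(32);
  return cudaLaunchKernel(
      reinterpret_cast<void*>(toyKernel), grid, block, args,
      /*sharedMem=*/0, stream);
}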

comms/common/algorithms/all_gather/AllGatherAlgoManager.cu

Lines changed: 18 additions & 36 deletions
@@ -5,39 +5,20 @@
 namespace meta::comms {
 
 AllGatherAlgoManager::AllGatherAlgoManager(
-    std::shared_ptr<ctran::bootstrap::IBootstrap> bootstrap,
     int nRanks,
     int selfRank,
     int maxBlocks,
     int ddaSendbufSizeBytes,
-    int ddaMaxThresholdBytes)
+    int ddaMaxThresholdBytes,
+    void** allRankDdaSendbuffs,
+    IpcGpuBarrier* barrier)
     : nRanks_(nRanks),
       selfRank_(selfRank),
       maxBlocks_(maxBlocks),
       ddaSendbufSizeBytes_(ddaSendbufSizeBytes),
-      ddaMaxThresholdBytes_(ddaMaxThresholdBytes) {
-  auto [barrierResources, barrier] =
-      IpcGpuBarrier::mallocAndInit(nRanks_, maxBlocks_, selfRank_, bootstrap);
-  barrierResources_ = std::move(barrierResources);
-  barrier_ = barrier;
-
-  ddaSendbuf_ = std::make_unique<DeviceBuffer>(ddaSendbufSizeBytes_);
-  memHandler_ = std::make_unique<IpcMemHandler>(bootstrap, selfRank_, nRanks_);
-  memHandler_->addSelfDeviceMemPtr(ddaSendbuf_->get());
-  memHandler_->exchangeMemPtrs();
-
-  std::vector<void*> ipcSendbufs(nRanks_);
-  for (int i = 0; i < nRanks_; ++i) {
-    ipcSendbufs[i] = memHandler_->getPeerDeviceMemPtr(i);
-  }
-
-  allRankDdaSendbuffs_ =
-      std::make_unique<DeviceBuffer>(sizeof(void*) * nRanks_);
-  CUDA_CHECK(cudaMemcpy(
-      allRankDdaSendbuffs_->get(),
-      ipcSendbufs.data(),
-      sizeof(void*) * nRanks_,
-      cudaMemcpyDefault));
+      ddaMaxThresholdBytes_(ddaMaxThresholdBytes),
+      allRankDdaSendbuffs_(allRankDdaSendbuffs),
+      barrier_(barrier) {
   XLOG(DBG) << "Successfully initialized AllGatherAlgoManager";
 }
 
@@ -46,15 +27,15 @@ std::unique_ptr<AlgoAllGather> AllGatherAlgoManager::getAllGatherAlgo(
     void* recvbuff,
     size_t count,
     commDataType_t datatype,
-    cudaStream_t stream,
-    const void* acc) {
-  if (count * commTypeSize(datatype) > ddaSendbufSizeBytes_) {
-    // msg size must fit into the dda sendbuf
+    cudaStream_t stream) {
+  if ((nRanks_ * count * commTypeSize(datatype)) > ddaSendbufSizeBytes_) {
+    // AG: msgSize = (nRanks_ x count x datatype) must fit into the dda sendbuf
     XLOG(DBG) << "Not using custom all gather algo because message size "
-              << count * commTypeSize(datatype)
+              << nRanks_ * count * commTypeSize(datatype)
               << " is larger than ddaSendbufSizeBytes " << ddaSendbufSizeBytes_;
     return nullptr;
   }
+
   if (((uintptr_t)sendbuff % 16) || ((uintptr_t)recvbuff % 16) ||
       ((count * commTypeSize(datatype)) % 16)) {
     // 16 byte alignment as we do 16-byte loads in DDA kernel
@@ -72,29 +53,30 @@ std::unique_ptr<AlgoAllGather> AllGatherAlgoManager::getAllGatherAlgo(
   }
 
   std::unique_ptr<AlgoAllGather> algo;
-  if (count * commTypeSize(datatype) > ddaMaxThresholdBytes_) {
+  if ((nRanks_ * count * commTypeSize(datatype)) > ddaMaxThresholdBytes_) {
+    // AG: msgSize = (nRanks_ x count x datatype) must not exceed the algo threshold
     XLOG(DBG) << "Not using custom all gather algo because msg size "
-              << count * commTypeSize(datatype)
+              << nRanks_ * count * commTypeSize(datatype)
               << " is larger than DDA algo threshold " << ddaMaxThresholdBytes_;
     return nullptr;
   } else {
-    if ((count * commTypeSize(datatype)) % 16) {
+    if (((count * commTypeSize(datatype)) % 16) ||
+        ((nRanks_ * count * commTypeSize(datatype)) % 16)) {
       XLOG(DBG) << "Not using DDA all gather algo because send/recv buff "
                    "or msg size is not 16-byte aligned for each rank";
       return nullptr;
     }
     algo = std::make_unique<AlgoAllGatherDdaIpc>(
         sendbuff,
-        reinterpret_cast<void**>(allRankDdaSendbuffs_->get()),
+        allRankDdaSendbuffs_,
        recvbuff,
        count,
        datatype,
        stream,
        nRanks_,
        selfRank_,
        maxBlocks_,
-        &barrier_,
-        acc);
+        barrier_);
   }
   return algo;
 }
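
Net effect of the updated checks: for all-gather, the size that must fit in the DDA sendbuf and stay under the threshold is the full gathered message, nRanks x count x typeSize, not just this rank's contribution. The same logic as a standalone predicate (an illustrative paraphrase, not code from this diff):

// Illustrative paraphrase of the gating logic above.
#include <cstddef>
#include <cstdint>

bool allGatherDdaEligible(
    int nRanks, size_t count, size_t typeSize,
    size_t ddaSendbufSizeBytes, size_t ddaMaxThresholdBytes,
    const void* sendbuff, const void* recvbuff) {
  const size_t perRankBytes = count * typeSize;
  const size_t totalBytes = static_cast<size_t>(nRanks) * perRankBytes;
  if (totalBytes > ddaSendbufSizeBytes) {
    return false; // gathered message must fit in the shared dda sendbuf
  }
  if (totalBytes > ddaMaxThresholdBytes) {
    return false; // above the DDA algo threshold
  }
  // 16-byte alignment, since the DDA kernel issues 16-byte loads/stores
  if (reinterpret_cast<uintptr_t>(sendbuff) % 16 != 0 ||
      reinterpret_cast<uintptr_t>(recvbuff) % 16 != 0 ||
      perRankBytes % 16 != 0 || totalBytes % 16 != 0) {
    return false;
  }
  return true;
}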
