
Commit 61abf03

Liu Ke authored and meta-codesync[bot] committed
shared func in CollCommon.cuh for AR, AG, RS, A2A
Summary:
This diff:
- allGather, reduceScatter, copyFromSrcToDest as shared functions in CollCommon.cuh
- ddaAllReduceTreeIpc: 1) allGather, 2) reduceScatter
- ddaAllReduceFlatIpc: 1) copyFromSrcToDest, 2) reduceScatter
- ddaAllGatherIpc: 1) copyFromSrcToDest, 2) allGather
- ddaReduceScatterIpc: 1) reduceScatter

No major performance regression before vs. after this change:
- ~1us latency increase on RS and A2A in the 1KB-16KB range, but latency starts to drop after 32KB --> considered positive
https://docs.google.com/spreadsheets/d/1Q4BHly_9ht8nbvt2IGK1XFDVlzIrevPTZ_0WI6zePew/edit?usp=sharing

Reviewed By: cenzhaometa

Differential Revision: D86485648

fbshipit-source-id: db7ebcbd6da13d9f5e23c6419858a5a01ffe4629
1 parent e569982 commit 61abf03
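A minimal sketch, not part of this commit, of how the flat AllReduce path from the summary might compose the new shared helpers. The kernel name allReduceFlatSketch, its parameter list, the IpcGpuBarrier type and its by-value use, and the choice of pattern value 2 for the flat reduce are assumptions inferred from the summary and from the comments in the new CollCommon.cuh below; the committed ddaAllReduceFlatIpc kernel is not shown in this excerpt.

// Sketch only: illustrates composing copyFromSrcToDest + reduceScatter the way
// the summary describes ddaAllReduceFlatIpc. Kernel name, parameter list,
// barrier type/usage, and the pattern value are assumptions, not the
// committed implementation.
#include "comms/common/IpcGpuBarrier.cuh"
#include "comms/common/algorithms/CollCommon.cuh"

namespace meta::comms {

template <SupportedTypes T, int NRANKS>
__global__ void allReduceFlatSketch(
    const T* sendbuff, // local input, count elements
    T* const* ipcbuffs, // IPC-mapped staging buffers, one per rank
    T* recvbuff, // local output, receives the full reduced vector
    int selfRank,
    size_t count, // assumed divisible by sizeof(uint4) / sizeof(T)
    IpcGpuBarrier barrier) { // type and by-value passing are assumptions
  constexpr auto countPerThread = sizeof(uint4) / sizeof(T);
  const auto gtIdx = blockDim.x * blockIdx.x + threadIdx.x;
  const auto idxStart = gtIdx * countPerThread;
  const auto idxEnd = count;
  const auto idxStride = gridDim.x * blockDim.x * countPerThread;

  // Step 1: publish my sendbuff into my IPC-visible staging buffer.
  copyFromSrcToDest<T>(
      sendbuff, ipcbuffs[selfRank], idxStart, idxEnd, idxStride);

  // Every rank must finish publishing before anyone reads peer buffers.
  barrier.syncOnSameBlockIdx<
      true /* hasPreviousMemAccess */,
      true /* hasSubsequentMemAccess */>();

  // Step 2: each rank reduces the whole vector over all peers' staging
  // buffers (pattern 2 = "reduce for AllReduce-Flat", no bias accumulator).
  reduceScatter<T, NRANKS, /*hasAcc=*/false>(
      ipcbuffs, recvbuff, /*acc=*/nullptr, selfRank,
      idxStart, idxEnd, idxStride, /*pattern=*/2);

  // Keep peers' buffers alive until every rank is done reading them
  // (template arguments here are assumed).
  barrier.syncOnSameBlockIdx<
      true /* hasPreviousMemAccess */,
      false /* hasSubsequentMemAccess */>();
}

} // namespace meta::comms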

File tree

10 files changed, +276 -203 lines changed

comms/common/algorithms/CollCommon.cuh

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace meta::comms {

template <typename T>
concept SupportedTypes =
    (std::same_as<T, half> || std::same_as<T, __nv_bfloat16>);

template <SupportedTypes T>
static inline __device__ uint32_t
vecElementAdd(const uint32_t& a, const uint32_t& b) {
  if constexpr (std::is_same<T, half>::value) {
    const __half* x = reinterpret_cast<const __half*>(&a);
    const __half* y = reinterpret_cast<const __half*>(&b);
    __half2 p = __halves2half2(x[0], x[1]);
    __half2 q = __halves2half2(y[0], y[1]);
    __half2 z = __hadd2(p, q);
    return (reinterpret_cast<uint32_t*>(&z))[0];
  } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
    const __nv_bfloat16* x = reinterpret_cast<const __nv_bfloat16*>(&a);
    const __nv_bfloat16* y = reinterpret_cast<const __nv_bfloat16*>(&b);
    __nv_bfloat162 p = {x[0], x[1]};
    __nv_bfloat162 q = {y[0], y[1]};
    __nv_bfloat162 z = __hadd2(p, q);
    return (reinterpret_cast<uint32_t*>(&z))[0];
  }
  return 0;
}

template <SupportedTypes T>
static inline __device__ uint4 vecElementAdd(const uint4& a, const uint4& b) {
  uint4 res{0, 0, 0, 0};
  res.x = vecElementAdd<T>(a.x, b.x);
  res.y = vecElementAdd<T>(a.y, b.y);
  res.z = vecElementAdd<T>(a.z, b.z);
  res.w = vecElementAdd<T>(a.w, b.w);
  return res;
}

template <SupportedTypes T>
static inline __device__ void copyFromSrcToDest(
    const T* __restrict__ srcbuff,
    T* __restrict__ destbuff,
    const size_t idxStart,
    const size_t idxEnd,
    const size_t idxStride) {
  for (size_t idx = idxStart; idx < idxEnd; idx += idxStride) {
    *reinterpret_cast<uint4*>(&destbuff[idx]) =
        reinterpret_cast<const uint4*>(&srcbuff[idx])[0];
  }
}

template <SupportedTypes T, int NRANKS, bool hasAcc>
static inline __device__ void reduceScatter(
    T* const* __restrict__ ipcbuffs,
    T* __restrict__ destbuff,
    const T* __restrict__ acc,
    int selfRank,
    const size_t idxStart,
    const size_t idxEnd,
    const size_t idxStride,
    int pattern) {
  /*
  This reduceScatter func handles three different patterns:
  - pattern is used to pick among the three patterns

  1st pattern: ReduceScatter itself
  Rank 0: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 1: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 2: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 3: chunk 0 | chunk 1 | chunk 2 | chunk 3
  ---> reduceScatter -->
  Rank 0: sum 0
  Rank 1: sum 1
  Rank 2: sum 2
  Rank 3: sum 3

  2nd pattern: ReduceScatter as the 2nd step inside AllReduce-Tree (RS + AG)
  Rank 0: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 1: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 2: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 3: chunk 0 | chunk 1 | chunk 2 | chunk 3
  ---> reduceScatter -->
  Rank 0: sum 0 | - | - | -
  Rank 1: - | sum 1 | - | -
  Rank 2: - | - | sum 2 | -
  Rank 3: - | - | - | sum 3

  3rd pattern: reduce for AllReduce-Flat
  Rank 0: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 1: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 2: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 3: chunk 0 | chunk 1 | chunk 2 | chunk 3
  ---> reduce -->
  Rank 0: sum 0 | sum 1 | sum 2 | sum 3
  Rank 1: sum 0 | sum 1 | sum 2 | sum 3
  Rank 2: sum 0 | sum 1 | sum 2 | sum 3
  Rank 3: sum 0 | sum 1 | sum 2 | sum 3
  */

  for (size_t idx = idxStart; idx < idxEnd; idx += idxStride) {
    size_t srcIdx = (pattern == 2) ? idx : (idx + selfRank * idxEnd);
    size_t destIdx = (pattern == 1) ? (idx + selfRank * idxEnd) : idx;

    uint4 sum{0, 0, 0, 0};
    // TODO: The bias accumulation needs to be moved to stage 2 if the bias
    // vector can be different on each rank. Currently we assume the bias vector
    // is the same across ranks.
    if constexpr (hasAcc) {
      sum = reinterpret_cast<const uint4*>(&acc[srcIdx])[0];
    }

    // Pipelining read val from other ranks and accumulation
    uint4 srcVals[2];
    // Prologue: read data from first rank
    *reinterpret_cast<uint4*>(&srcVals[0]) =
        reinterpret_cast<const uint4*>(&ipcbuffs[0][srcIdx])[0];
#pragma unroll NRANKS - 1
    for (int r = 0; r < NRANKS - 1; ++r) {
      // Kick-off reading data from next rank
      *reinterpret_cast<uint4*>(&srcVals[(r + 1) & 1]) =
          reinterpret_cast<const uint4*>(
              &ipcbuffs[(r + 1) % NRANKS][srcIdx])[0];
      // Do accumulation for current rank
      sum = vecElementAdd<T>(sum, srcVals[r & 1]);
    }
    // Epilogue: accumulation for last rank
    sum = vecElementAdd<T>(sum, srcVals[(NRANKS - 1) & 1]);

    // Store to the destination buffer
    *reinterpret_cast<uint4*>(&destbuff[destIdx]) =
        *reinterpret_cast<const uint4*>(&sum);
  }
}

template <SupportedTypes T, int NRANKS>
static inline __device__ void allGather(
    T* const* __restrict__ ipcbuffs,
    T* __restrict__ destbuff,
    int selfRank,
    const size_t idxStart,
    const size_t idxEnd,
    const size_t idxStride,
    bool enable_offset) {
  /*
  This allGather func handles two different patterns:
  - enable_offset is used to pick between the two patterns

  1st pattern: AllGather itself
  Rank 0: chunk 0
  Rank 1: chunk 1
  Rank 2: chunk 2
  Rank 3: chunk 3
  ---> AllGather -->
  Rank 0: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 1: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 2: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 3: chunk 0 | chunk 1 | chunk 2 | chunk 3

  2nd pattern: AllGather as the 2nd step inside AllReduce (RS + AG)
  Rank 0: chunk 0 | - | - | -
  Rank 1: - | chunk 1 | - | -
  Rank 2: - | - | chunk 2 | -
  Rank 3: - | - | - | chunk 3
  ---> AllGather -->
  Rank 0: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 1: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 2: chunk 0 | chunk 1 | chunk 2 | chunk 3
  Rank 3: chunk 0 | chunk 1 | chunk 2 | chunk 3
  */

  for (size_t idx = idxStart; idx < idxEnd; idx += idxStride) {
#pragma unroll NRANKS
    for (int r = 0; r < NRANKS; ++r) {
      int srcRank = (selfRank + r) % NRANKS;
      int destIdx = idx + srcRank * idxEnd;
      int srcIdx;
      if (enable_offset) {
        srcIdx = destIdx;
      } else {
        srcIdx = idx;
      }
      *reinterpret_cast<uint4*>(&destbuff[destIdx]) =
          reinterpret_cast<const uint4*>(&ipcbuffs[srcRank][srcIdx])[0];
    }
  }
}

} // namespace meta::comms
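To make the pattern argument of reduceScatter concrete, here is a small host-side worked example of the index mapping computed in the loop above. The integer-to-pattern mapping (0 = plain ReduceScatter, 1 = RS inside AllReduce-Tree, 2 = reduce for AllReduce-Flat) is inferred from the branch conditions, and the helper names rsSrcIdx / rsDestIdx are hypothetical:

// Hypothetical illustration of reduceScatter's index mapping:
//   srcIdx  = (pattern == 2) ? idx : idx + selfRank * idxEnd
//   destIdx = (pattern == 1) ? idx + selfRank * idxEnd : idx
#include <cstddef>

constexpr size_t rsSrcIdx(int pattern, size_t idx, int selfRank, size_t idxEnd) {
  return (pattern == 2) ? idx : idx + selfRank * idxEnd;
}
constexpr size_t rsDestIdx(int pattern, size_t idx, int selfRank, size_t idxEnd) {
  return (pattern == 1) ? idx + selfRank * idxEnd : idx;
}

// Example: 4 ranks, 1024 elements per chunk, selfRank = 2, idx = 8.
static_assert(rsSrcIdx(0, 8, 2, 1024) == 2056 && rsDestIdx(0, 8, 2, 1024) == 8);
static_assert(rsSrcIdx(1, 8, 2, 1024) == 2056 && rsDestIdx(1, 8, 2, 1024) == 2056);
static_assert(rsSrcIdx(2, 8, 2, 1024) == 8 && rsDestIdx(2, 8, 2, 1024) == 8);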

comms/common/algorithms/VecElementAdd.cuh

Lines changed: 0 additions & 46 deletions
This file was deleted.

comms/common/algorithms/all_gather/all_gather_dda.cuh

Lines changed: 9 additions & 17 deletions
@@ -4,6 +4,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include "comms/common/IpcGpuBarrier.cuh"
+#include "comms/common/algorithms/CollCommon.cuh"
 
 namespace meta::comms {
 
@@ -22,34 +23,25 @@ __launch_bounds__(512)
   // We assume that count % countPerThread == 0. This assumption is enforced
   // before kernel launch
   // TODO: we should be able to deal with left over as well
+  const size_t countPerRank = count;
   constexpr auto countPerThread = sizeof(uint4) / sizeof(T);
-  const auto idxStride = gridDim.x * blockDim.x * countPerThread;
   const auto gtIdx = blockDim.x * blockIdx.x + threadIdx.x;
+
   const auto idxStart = gtIdx * countPerThread;
-  const auto idxEnd = count;
+  const auto idxEnd = countPerRank;
+  const auto idxStride = gridDim.x * blockDim.x * countPerThread;
 
   // It is expensive to launch hipMemcpyAsync on ROCm
   // Move data copy here. Each block copies part of sendbuff data
-  T* ipcbuff = ipcbuffs[selfRank];
-  for (size_t idx = idxStart; idx < idxEnd; idx += idxStride) {
-    *reinterpret_cast<uint4*>(&ipcbuff[idx]) =
-        reinterpret_cast<const uint4*>(&sendbuff[idx])[0];
-  }
+  copyFromSrcToDest<T>(
+      sendbuff, ipcbuffs[selfRank], idxStart, idxEnd, idxStride);
 
   barrier.syncOnSameBlockIdx<
       true /* hasPreviousMemAccess */,
       true /* hasSubsequentMemAccess */>();
 
-  for (size_t idx = idxStart; idx < idxEnd; idx += idxStride) {
-    // Store to the destination buffer.
-#pragma unroll NRANKS
-    for (int r = 0; r < NRANKS; ++r) {
-      int srcRank = (selfRank + r) % NRANKS;
-      int srcIdx = idx + srcRank * idxEnd;
-      *reinterpret_cast<uint4*>(&recvbuff[srcIdx]) =
-          reinterpret_cast<const uint4*>(&ipcbuffs[srcRank][idx])[0];
-    }
-  }
+  allGather<T, NRANKS>(
+      ipcbuffs, recvbuff, selfRank, idxStart, idxEnd, idxStride, false);
 
   // barrier to ensure remote ranks won't free their buffers until I'm done
   barrier.syncOnSameBlockIdx<
0 commit comments

Comments
 (0)