Commit 89bde53

Optimize CAR for ROCm (#225)
* Optimize CAR for ROCm
* tune block numbers
* increase cutoff to RCCL fallback to 16 MB
* scope atomics
* remove volatiles
* Pacify linters.
1 parent 2550f14 commit 89bde53
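
The "scope atomics" and "remove volatiles" items replace the volatile-qualified `Signal` pointers and plain `__atomic_*` builtins with Clang's scoped atomic builtins on the ROCm path: the cross-GPU flag write uses system scope, while the local spin-wait only needs device scope. Below is a minimal device-side sketch of that signaling pattern; the function name and flag layout are illustrative, not the actual `Signal` struct from the diff.

```cpp
#include <cstdint>

// Minimal HIP sketch of the scoped-atomic signaling pattern used on ROCm.
// `peer_flag` points into a peer GPU's signal buffer (mapped via IPC);
// `my_flag` lives in this GPU's memory. Names and layout are illustrative.
__device__ void signal_and_wait(uint32_t* peer_flag, const uint32_t* my_flag,
                                uint32_t flag) {
  // Publish our flag to the peer: the store must become visible across
  // GPUs, so it uses system scope.
  __scoped_atomic_store_n(peer_flag, flag, __ATOMIC_RELAXED,
                          __MEMORY_SCOPE_SYSTEM);
  // Spin until the peer has bumped our local flag: the load only observes
  // this device's memory, so device scope is enough.
  while (__scoped_atomic_load_n(my_flag, __ATOMIC_RELAXED,
                                __MEMORY_SCOPE_DEVICE) < flag) {
  }
}
```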

3 files changed: +107 additions, -64 deletions

csrc/custom_all_reduce.cuh

Lines changed: 93 additions & 62 deletions
```diff
@@ -43,7 +43,12 @@ struct __align__(16) RankData { const void* ptrs[8]; };
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 #endif
 
-struct __align__(16) RankSignals { volatile Signal* signals[8]; };
+struct __align__(16) RankSignals {
+#ifndef USE_ROCM
+  volatile
+#endif
+      Signal* signals[8];
+};
 
 // like std::array, but aligned
 template <typename T, int sz>
@@ -138,18 +143,23 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // prior memory accesses. Note: volatile writes will not be reordered against
 // other volatile writes.
 template <int ngpus>
-DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
+DINLINE void start_sync(const RankSignals& sg,
+#ifndef USE_ROCM
+                        volatile
+#endif
+                        Signal* self_sg,
                         int rank) {
 #ifdef USE_ROCM
   uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
-                     __ATOMIC_RELAXED);
+    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
+                            flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
-                           __ATOMIC_RELAXED) < flag);
+    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
+                                  __ATOMIC_RELAXED,
+                                  __MEMORY_SCOPE_DEVICE) < flag);
   }
   __syncthreads();
   // use one thread to update flag
@@ -172,7 +182,11 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
+DINLINE void end_sync(const RankSignals& sg,
+#ifndef USE_ROCM
+                      volatile
+#endif
+                      Signal* self_sg,
                       int rank) {
 #ifdef USE_ROCM
   __syncthreads();
@@ -184,12 +198,15 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
-                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
+    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
+                            flag,
+                            final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
+                            __MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
-                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
-           flag);
+    while (
+        __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
+                               final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
+                               __MEMORY_SCOPE_DEVICE) < flag);
   }
   __syncthreads();
   // use one thread to update flag
@@ -227,8 +244,11 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+#ifndef USE_ROCM
+                               volatile
+#endif
+                               Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
@@ -244,15 +264,22 @@ __global__ void __launch_bounds__(512, 1)
 }
 
 template <typename P>
+#ifdef USE_ROCM
+DINLINE P* get_tmp_buf(Signal* sg) {
+#else
 DINLINE P* get_tmp_buf(volatile Signal* sg) {
+#endif
   return (P*)(((Signal*)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+#ifndef USE_ROCM
+                               volatile
+#endif
+                               Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
@@ -455,37 +482,41 @@ class CustomAllreduce {
    */
   template <typename T>
   void allreduce(cudaStream_t stream, T* input, T* output, int size,
-                 int threads = 512, int block_limit = 36) {
-    auto d = packed_t<T>::P::size;
-    if (size % d != 0)
+#ifndef USE_ROCM
+                 int threads = 512, int block_limit = 36){
+#else
+                 int threads = 512, int block_limit = 16) {
+#endif
+    auto d = packed_t<T>::P::size;
+    if (size % d != 0)
+      throw std::runtime_error(
+          "custom allreduce currently requires input length to be multiple "
+          "of " +
+          std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
+
+    RankData* ptrs;
+    cudaStreamCaptureStatus status;
+    CUDACHECK(cudaStreamIsCapturing(stream, &status));
+    if (status == cudaStreamCaptureStatusActive) {
+      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
+      graph_unreg_buffers_.push_back(input);
+    } else {
+      auto it = buffers_.find(input);
+      if (it == buffers_.end())
         throw std::runtime_error(
-            "custom allreduce currently requires input length to be multiple "
-            "of " +
-            std::to_string(d));
-    if (block_limit > kMaxBlocks)
-      throw std::runtime_error("max supported block limit is " +
-                               std::to_string(kMaxBlocks) + ". Got " +
-                               std::to_string(block_limit));
-
-    RankData* ptrs;
-    cudaStreamCaptureStatus status;
-    CUDACHECK(cudaStreamIsCapturing(stream, &status));
-    if (status == cudaStreamCaptureStatusActive) {
-      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
-      graph_unreg_buffers_.push_back(input);
-    } else {
-      auto it = buffers_.find(input);
-      if (it == buffers_.end())
-        throw std::runtime_error(
-            "buffer address " +
-            std::to_string(reinterpret_cast<uint64_t>(input)) +
-            " is not registered!");
-      ptrs = it->second;
-    }
+            "buffer address " +
+            std::to_string(reinterpret_cast<uint64_t>(input)) +
+            " is not registered!");
+      ptrs = it->second;
+    }
 
-    size /= d;
-    auto bytes = size * sizeof(typename packed_t<T>::P);
-    int blocks = std::min(block_limit, (size + threads - 1) / threads);
+    size /= d;
+    auto bytes = size * sizeof(typename packed_t<T>::P);
+    int blocks = std::min(block_limit, (size + threads - 1) / threads);
 #define KL(ngpus, name)                                                       \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
@@ -504,27 +535,27 @@ class CustomAllreduce {
       break; \
   }
 
-    switch (world_size_) {
-      REDUCE_CASE(2)
-      REDUCE_CASE(4)
-      REDUCE_CASE(6)
-      REDUCE_CASE(8)
-      default:
-        throw std::runtime_error(
-            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
-            "gpus = " +
-            std::to_string(world_size_));
-    }
+    switch (world_size_) {
+      REDUCE_CASE(2)
+      REDUCE_CASE(4)
+      REDUCE_CASE(6)
+      REDUCE_CASE(8)
+      default:
+        throw std::runtime_error(
+            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
+            "gpus = " +
+            std::to_string(world_size_));
+    }
 #undef REDUCE_CASE
 #undef KL
-  }
+  }
 
-  ~CustomAllreduce() {
-    for (auto [_, ptr] : ipc_handles_) {
-      CUDACHECK(cudaIpcCloseMemHandle(ptr));
-    }
+  ~CustomAllreduce() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CUDACHECK(cudaIpcCloseMemHandle(ptr));
     }
-};
+  }
+};  // namespace vllm
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
```

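A note on the tuned launch defaults above: on ROCm the grid is capped at 16 blocks of 512 threads (36 blocks elsewhere). The sketch below reproduces just that grid-size computation as a standalone host function, outside the CustomAllreduce class; the helper name is illustrative.

```cpp
#include <algorithm>
#include <cstdio>

// Standalone copy of the grid-size computation in CustomAllreduce::allreduce.
// `size` is the packed element count (input length / packed_t<T>::P::size).
static int car_grid_size(int size, int threads = 512,
#ifdef USE_ROCM
                         int block_limit = 16
#else
                         int block_limit = 36
#endif
) {
  return std::min(block_limit, (size + threads - 1) / threads);
}

int main() {
  // Large inputs hit the cap: a ROCm build prints 16, other builds print 36.
  std::printf("blocks = %d\n", car_grid_size(1 << 20));
  return 0;
}
```
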
csrc/custom_all_reduce_test.cu

Lines changed: 7 additions & 0 deletions
```diff
@@ -330,10 +330,17 @@ int main(int argc, char** argv) {
   //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
   //   }
   // }
+#ifdef USE_ROCM
+  for (int sz = 512; sz <= (8 << 20); sz *= 2) {
+    run<half>(myRank, nRanks, comm, 512, 16, sz + 8 * 47, performance_test);
+  }
+#else
   for (int sz = 512; sz <= (8 << 20); sz *= 2) {
     run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
   }
+#endif
 
   cudaProfilerStop();
+  MPICHECK(MPI_Finalize());
   return EXIT_SUCCESS;
 }
```

vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -12,7 +12,7 @@
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils import cuda_device_count_stateless, is_hip
 
 try:
     ops.meta_size()
@@ -44,10 +44,15 @@ class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
 
     # max_size: max supported allreduce size
+    _MAX_CAR_SIZE = 8192 * 1024
+    if is_hip():
+        # crossover is at 16MB buffer size for ROCm
+        _MAX_CAR_SIZE = 2 * 8192 * 1024
+
     def __init__(self,
                  group: ProcessGroup,
                  device: Union[int, str, torch.device],
-                 max_size=8192 * 1024) -> None:
+                 max_size=_MAX_CAR_SIZE) -> None:
         """
         Args:
             group: the process group to work on. If None, it will use the
```

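The Python change above doubles the maximum buffer size handled by the custom all-reduce (CAR) on ROCm, matching the commit note that the crossover to RCCL sits at 16 MB. As a rough illustration only, not the actual vLLM dispatch logic, the crossover check amounts to something like:

```cpp
#include <cstddef>

// Hypothetical helper mirroring the _MAX_CAR_SIZE crossover; not part of
// vLLM, purely illustrative.
constexpr std::size_t kMaxCarSizeCuda = 8192 * 1024;      // 8 MiB
constexpr std::size_t kMaxCarSizeRocm = 2 * 8192 * 1024;  // 16 MiB

inline bool use_custom_allreduce(std::size_t nbytes, bool is_rocm) {
  // Below the cutoff the custom all-reduce kernels win; above it, fall back
  // to NCCL/RCCL.
  return nbytes <= (is_rocm ? kMaxCarSizeRocm : kMaxCarSizeCuda);
}
```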