@@ -43,7 +43,12 @@ struct __align__(16) RankData { const void* ptrs[8]; };
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 #endif
 
-struct __align__(16) RankSignals { volatile Signal* signals[8]; };
+struct __align__(16) RankSignals {
+#ifndef USE_ROCM
+  volatile
+#endif
+      Signal* signals[8];
+};
 
 // like std::array, but aligned
 template <typename T, int sz>
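
Note: on the ROCm path the volatile qualifier is compiled out because that path synchronizes through the atomic builtins used further down, so the signal pointers are not declared volatile there. For orientation, the Signal object these pointers refer to is indexed as a per-block, per-rank flag table; the following is a rough sketch of its shape inferred from the indexing in this diff (field sizes and alignment are assumptions, not the upstream definition; kMaxBlocks is the header's block-count limit):

    // Sketch only: layout inferred from self_sg->start[blockIdx.x][rank],
    // self_sg->end[blockIdx.x][rank] and self_sg->_flag[blockIdx.x] (ROCm path).
    struct __align__(16) Signal {
      alignas(128) uint32_t start[kMaxBlocks][8];  // start-barrier flags, per block per rank
      alignas(128) uint32_t end[kMaxBlocks][8];    // end-barrier flags, per block per rank
      alignas(128) uint32_t _flag[kMaxBlocks];     // per-block epoch counter (ROCm path)
    };
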
@@ -138,18 +143,23 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // prior memory accesses. Note: volatile writes will not be reordered against
 // other volatile writes.
 template <int ngpus>
-DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
+DINLINE void start_sync(const RankSignals& sg,
+#ifndef USE_ROCM
+                        volatile
+#endif
+                        Signal* self_sg,
                         int rank) {
 #ifdef USE_ROCM
   uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
-                     __ATOMIC_RELAXED);
+    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
+                            flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
-                           __ATOMIC_RELAXED) < flag);
+    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
+                                  __ATOMIC_RELAXED,
+                                  __MEMORY_SCOPE_DEVICE) < flag);
   }
   __syncthreads();
   // use one thread to update flag
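
Note: the ROCm branch swaps the plain __atomic_* builtins for Clang's scoped variants so the visibility scope is stated explicitly: the flag written into a peer GPU's signal uses __MEMORY_SCOPE_SYSTEM, while the spin-load on the local signal only needs __MEMORY_SCOPE_DEVICE. Below is a minimal, self-contained sketch of the same handshake pattern, with hypothetical function and variable names; it assumes a compiler that provides the scoped atomic builtins (e.g. a recent hipcc/clang):

    // Toy peer handshake, not the vLLM kernel: peer_flag is assumed to be
    // mapped into the remote GPU's address space (e.g. via IPC); my_flag is local.
    __device__ void toy_start_barrier(uint32_t* peer_flag, uint32_t* my_flag,
                                      uint32_t epoch) {
      // publish arrival so the other GPU can observe it (system scope)
      __scoped_atomic_store_n(peer_flag, epoch, __ATOMIC_RELAXED,
                              __MEMORY_SCOPE_SYSTEM);
      // wait for the matching write from the peer; only this device reads my_flag
      while (__scoped_atomic_load_n(my_flag, __ATOMIC_RELAXED,
                                    __MEMORY_SCOPE_DEVICE) < epoch) {
      }
    }
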
@@ -172,7 +182,11 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
+DINLINE void end_sync(const RankSignals& sg,
+#ifndef USE_ROCM
+                      volatile
+#endif
+                      Signal* self_sg,
                       int rank) {
 #ifdef USE_ROCM
   __syncthreads();
@@ -184,12 +198,15 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
   if (threadIdx.x < ngpus) {
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
-                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
+    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
+                            flag,
+                            final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
+                            __MEMORY_SCOPE_SYSTEM);
     // wait until we got true from all ranks
-    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
-                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
-           flag);
+    while (
+        __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
+                               final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
+                               __MEMORY_SCOPE_DEVICE) < flag);
   }
   __syncthreads();
   // use one thread to update flag
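
Note: the memory orders are carried over unchanged from the non-scoped version: for an intermediate barrier the store is a release and the matching load an acquire, so data written before the barrier is visible to the rank that passes it; the final barrier, after the last read, can stay relaxed. Mirroring the scopes used above, the pairing looks roughly like this (illustrative helpers, not part of this file):

    // Why release/acquire matters for the non-final barrier (sketch only):
    __device__ void producer(float* data, uint32_t* flag, uint32_t epoch) {
      data[0] = 42.f;  // payload written before signalling
      __scoped_atomic_store_n(flag, epoch, __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM);
    }
    __device__ float consumer(const float* data, uint32_t* flag, uint32_t epoch) {
      while (__scoped_atomic_load_n(flag, __ATOMIC_ACQUIRE,
                                    __MEMORY_SCOPE_DEVICE) < epoch) {
      }
      return data[0];  // acquire pairs with the release, so the payload is visible
    }
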
@@ -227,8 +244,11 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+#ifndef USE_ROCM
+                               volatile
+#endif
+                               Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
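
Note: only the kernel signature changes here (the ROCm-conditional volatile). The body still walks the input grid-stride in packed vectors and, per index, sums the value from every rank's registered buffer in the same order on all ranks via packed_reduce. Stripped of the packing and the start/end barriers, the per-element work is roughly:

    // Simplified 1-stage reduction for plain float (sketch, not the real kernel);
    // rank_ptrs[i] is assumed to point at rank i's registered input buffer.
    __device__ void naive_1stage(const float* const* rank_ptrs, float* out,
                                 int ngpus, int size) {
      int stride = gridDim.x * blockDim.x;
      for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
           idx += stride) {
        float acc = 0.f;
        for (int i = 0; i < ngpus; i++)
          acc += rank_ptrs[i][idx];  // identical accumulation order on every rank
        out[idx] = acc;
      }
    }
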
@@ -244,15 +264,22 @@ __global__ void __launch_bounds__(512, 1)
 }
 
 template <typename P>
+#ifdef USE_ROCM
+DINLINE P* get_tmp_buf(Signal* sg) {
+#else
 DINLINE P* get_tmp_buf(volatile Signal* sg) {
+#endif
   return (P*)(((Signal*)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-                               volatile Signal* self_sg, T* __restrict__ result,
-                               int rank, int size) {
+#ifndef USE_ROCM
+                               volatile
+#endif
+                               Signal* self_sg,
+                               T* __restrict__ result, int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
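
Note: get_tmp_buf only loses its volatile qualifier on ROCm, and the 2-stage kernel gets the same signature treatment as the 1-stage one. The pointer arithmetic is unchanged: it returns the scratch region the 2-stage kernel uses, which sits immediately after the Signal object in the same allocation. Spelled out:

    // Equivalent to (P*)(((Signal*)sg) + 1): skip one Signal-sized header and
    // treat the rest of the allocation as the packed temporary buffer.
    template <typename P>
    DINLINE P* tmp_buf_after_signal(Signal* sg) {
      return reinterpret_cast<P*>(reinterpret_cast<char*>(sg) + sizeof(Signal));
    }
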
@@ -455,37 +482,41 @@ class CustomAllreduce {
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
-                 int threads = 512, int block_limit = 36) {
-    auto d = packed_t<T>::P::size;
-    if (size % d != 0)
+#ifndef USE_ROCM
+                 int threads = 512, int block_limit = 36){
+#else
+                 int threads = 512, int block_limit = 16) {
+#endif
+    auto d = packed_t<T>::P::size;
+    if (size % d != 0)
+      throw std::runtime_error(
+          "custom allreduce currently requires input length to be multiple "
+          "of " +
+          std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
+
+    RankData* ptrs;
+    cudaStreamCaptureStatus status;
+    CUDACHECK(cudaStreamIsCapturing(stream, &status));
+    if (status == cudaStreamCaptureStatusActive) {
+      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
+      graph_unreg_buffers_.push_back(input);
+    } else {
+      auto it = buffers_.find(input);
+      if (it == buffers_.end())
         throw std::runtime_error(
-          "custom allreduce currently requires input length to be multiple "
-          "of " +
-          std::to_string(d));
-    if (block_limit > kMaxBlocks)
-      throw std::runtime_error("max supported block limit is " +
-                               std::to_string(kMaxBlocks) + ". Got " +
-                               std::to_string(block_limit));
-
-    RankData* ptrs;
-    cudaStreamCaptureStatus status;
-    CUDACHECK(cudaStreamIsCapturing(stream, &status));
-    if (status == cudaStreamCaptureStatusActive) {
-      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
-      graph_unreg_buffers_.push_back(input);
-    } else {
-      auto it = buffers_.find(input);
-      if (it == buffers_.end())
-        throw std::runtime_error(
-            "buffer address " +
-            std::to_string(reinterpret_cast<uint64_t>(input)) +
-            " is not registered!");
-      ptrs = it->second;
-    }
+            "buffer address " +
+            std::to_string(reinterpret_cast<uint64_t>(input)) +
+            " is not registered!");
+      ptrs = it->second;
+    }
 
-    size /= d;
-    auto bytes = size * sizeof(typename packed_t<T>::P);
-    int blocks = std::min(block_limit, (size + threads - 1) / threads);
+    size /= d;
+    auto bytes = size * sizeof(typename packed_t<T>::P);
+    int blocks = std::min(block_limit, (size + threads - 1) / threads);
 #define KL(ngpus, name)                                                       \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
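
Note: functionally, this hunk only changes the default launch shape: block_limit defaults to 36 blocks on CUDA and 16 on ROCm, while threads stays at 512; the remaining matched -/+ pairs re-emit the same statements and differ only in whitespace. A hypothetical call site (object and buffer names are placeholders, not the registration API):

    // vllm::CustomAllreduce fa{...};                 // constructed, buffers registered elsewhere
    // fa.allreduce(stream, d_in, d_out, num_elems);            // platform default block_limit
    // fa.allreduce(stream, d_in, d_out, num_elems, 256, 8);    // explicit threads / block_limit
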
@@ -504,27 +535,27 @@ class CustomAllreduce {
       break;                                                                 \
     }
 
-    switch (world_size_) {
-      REDUCE_CASE(2)
-      REDUCE_CASE(4)
-      REDUCE_CASE(6)
-      REDUCE_CASE(8)
-      default:
-        throw std::runtime_error(
-            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
-            "gpus = " +
-            std::to_string(world_size_));
-    }
+    switch (world_size_) {
+      REDUCE_CASE(2)
+      REDUCE_CASE(4)
+      REDUCE_CASE(6)
+      REDUCE_CASE(8)
+      default:
+        throw std::runtime_error(
+            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
+            "gpus = " +
+            std::to_string(world_size_));
+    }
 #undef REDUCE_CASE
 #undef KL
-  }
+  }
 
-  ~CustomAllreduce() {
-    for (auto [_, ptr] : ipc_handles_) {
-      CUDACHECK(cudaIpcCloseMemHandle(ptr));
-    }
+  ~CustomAllreduce() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CUDACHECK(cudaIpcCloseMemHandle(ptr));
   }
-};
+  }
+};  // namespace vllm
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
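
Note: the instantiation the comment asks for would be along these lines (type chosen for illustration):

    // template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half*, half*,
    //                                                      int, int, int);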