@@ -61,6 +61,31 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
     return __float2bfloat16(val);
 }
 
+__device__ float4 loadfloat4(void const* ptr)
+{
+
+    float return_value[4];
+
+    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
+                 : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
+                 : "l"(ptr));
+
+    return *(float4*) return_value;
+}
+
+__device__ __inline__ float2 loadfloat2(void const* ptr)
+{
+
+    float return_value[2];
+
+    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
+                 : "=f"(return_value[0]), "=f"(return_value[1])
+                 : "l"(ptr)
+                 : "memory");
+
+    return *(float2*) return_value;
+}
+
 template <int WORLD_SIZE, typename T>
 __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
     int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results)
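
Aside: the two helpers added above do volatile vector loads so a polling loop observes peer writes instead of a cached value. A minimal illustrative sketch of how such a load pairs with the -0.0f "not yet written" sentinel used by these kernels (the helper name poll_float2 is hypothetical, not part of this change):

__device__ __forceinline__ float2 poll_float2(void const* ptr)
{
    // Spin until the 8-byte slot no longer holds the -0.0f sentinel;
    // 0x80000000 is the bit pattern of -0.0f.
    float2 v = loadfloat2(ptr);
    while (__float_as_uint(v.x) == 0x80000000u)
    {
        v = loadfloat2(ptr);
    }
    return v;
}
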
@@ -74,20 +99,13 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     cudaGridDependencySynchronize();
 #endif
 
+    // [input_ptr, clear_ptr, buffer_size, access_counter]
+    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+    // Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
+    uint32_t buffer_group_size = flag.z << 1;
+    uint32_t input_offset = flag.x * buffer_group_size;
+    uint32_t clear_offset = flag.y * buffer_group_size;
     uint32_t* offset_access_ptr = &buffer_flags[3];
-    // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = (buffer_flags[2] << 1);
-    uint32_t input_offset = buffer_flags[0] * buffer_size;
-    uint32_t clear_offset = buffer_flags[1] * buffer_size;
-
-    if (wait_for_results)
-    {
-        __syncthreads();
-        if (threadIdx.x == 0)
-        {
-            atomicAdd(offset_access_ptr, 1);
-        }
-    }
 
     if (elt < token_dim)
     {
@@ -101,17 +119,16 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
 
         // Reduce and broadcast
 
-        int global_token = token * WORLD_SIZE + rank;
-        if (global_token < num_tokens)
+        if ((token % WORLD_SIZE) == rank)
         {
-
+            int local_token = token / WORLD_SIZE;
             float accum = 0.f;
 
             T values[WORLD_SIZE];
 
             for (int r = 0; r < WORLD_SIZE; r++)
             {
-                input_ptrs[rank][clear_offset + token * token_dim * WORLD_SIZE + r * token_dim + elt]
+                input_ptrs[rank][clear_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]
                     = fromFloat<T>(-0.f);
             }
 
@@ -121,7 +138,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
             for (int r = 0; r < WORLD_SIZE; r++)
             {
                 T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
-                    + token * token_dim * WORLD_SIZE + r * token_dim + elt];
+                    + local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
                 values[r] = *lamport_ptr;
                 valid &= !isNegZero(values[r]);
             }
@@ -132,7 +149,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
             {
                 accum += toFloat<T>(values[r]);
             }
-            mcast_ptr[input_offset + buffer_M * token_dim + global_token * token_dim + elt] = fromFloat<T>(accum);
+            mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
         }
     }
 
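
Aside: the index arithmetic above assumes each buffer group holds a reduce-scatter region of buffer_M * token_dim elements followed by an allgather region of the same size. A sketch of that layout as helpers (names are illustrative, not from the patch):

__device__ __forceinline__ size_t scatter_slot(
    uint32_t group_offset, int token, int src_rank, int world_size, int token_dim, int elt)
{
    // Slot where rank src_rank deposits its shard of `token`; the owning rank
    // is token % world_size and local_token = token / world_size.
    return group_offset + size_t(token / world_size) * token_dim * world_size + size_t(src_rank) * token_dim + elt;
}

__device__ __forceinline__ size_t allgather_slot(
    uint32_t group_offset, int buffer_M, int token, int token_dim, int elt)
{
    // Slot in the allgather region where the reduced value of `token` is broadcast.
    return group_offset + size_t(buffer_M) * token_dim + size_t(token) * token_dim + elt;
}
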
@@ -145,23 +162,50 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     // Optionally wait for results if the next layer isn't doing the Lamport check
     if (wait_for_results)
     {
-        T volatile* lamport_ptr
-            = (T volatile*) &input_ptrs[rank][input_offset + buffer_M * token_dim + token * token_dim + elt];
-        T val = *lamport_ptr;
-        while (isNegZero(val))
-            val = *lamport_ptr;
-
-        // Copy if requested
-        if (output_ptr)
-            output_ptr[token * token_dim + elt] = val;
-        if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
+        // Update the atomic counter to indicate the block has read the offsets
+        __syncthreads();
+
+        if (threadIdx.x == 0)
+        {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+            atomicAdd(offset_access_ptr, 1);
+#endif
+        }
+        // Only use a subset of CTAs for the Lamport sync; rearrange the grid
+        constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
+        // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
+        if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
+        {
+            uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
+
+            void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
+            // We have 2 assumptions here:
+            // 1. Writes are atomic at 8B granularity -> each buffer in the buffer group should be 8B-aligned
+            // 2. num_tokens * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
+            float2 val = loadfloat2(lamport_ptr);
+            while (isNegZero(*(T*) &val))
+            {
+                val = loadfloat2(lamport_ptr);
+            }
+            if (output_ptr)
+            {
+                *((float2*) &output_ptr[current_pos]) = val;
+            }
+        }
+
+        // Update the buffer flags
+        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
         {
             // Make sure all blocks have finished reading the offsets, 2-D grid
             while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
             {
             }
-            buffer_flags[0] = (buffer_flags[0] + 1) % 3;
-            buffer_flags[1] = (buffer_flags[1] + 1) % 3;
+            buffer_flags[0] = (flag.x + 1) % 3;
+            buffer_flags[1] = (flag.y + 1) % 3;
             *(offset_access_ptr) = 0;
         }
     }
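
Aside: the arch-guarded counter update above is repeated verbatim in the fusion kernel further down; as a sketch only (not part of this change), it could be factored into a helper:

__device__ __forceinline__ void release_add_u32(uint32_t* ptr, uint32_t val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
    asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory");
#else
    atomicAdd(ptr, val);
#endif
}
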
@@ -251,18 +295,6 @@ __device__ void copy_f4_ldg(T_IN* dst, T_IN const* src)
     *dst4 = *src4;
 }
 
-__device__ float4 loadfloat4(void const* ptr)
-{
-
-    float return_value[4];
-
-    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
-                 : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
-                 : "l"(ptr));
-
-    return *(float4*) return_value;
-}
-
 template <typename T>
 inline __device__ T add(T a, T b)
 {
@@ -322,19 +354,14 @@ __global__ void __launch_bounds__(128, 1)
     int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];
 
     uint32_t* offset_access_ptr = &buffer_flags[3];
+    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
     // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = buffer_flags[2];
-    uint32_t buffer_offset = buffer_flags[0] * (buffer_size << 1);
+    uint32_t buffer_size = flag.z;
+    uint32_t buffer_offset = flag.x * (buffer_size << 1);
     T_IN const* input = &buffer_input[buffer_offset + buffer_size];
 
     cudaTriggerProgrammaticLaunchCompletion();
 
-    __syncthreads();
-    if (threadIdx.x == 0)
-    {
-        atomicAdd(offset_access_ptr, 1);
-    }
-
     for (int i = 0; i < NUM_INPUTS; i++)
     {
         for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++)
@@ -361,7 +388,17 @@ __global__ void __launch_bounds__(128, 1)
         }
 
     __pipeline_commit();
-
+    __syncthreads();
+    if (threadIdx.x == 0)
+    {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+        asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+        asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+        atomicAdd(offset_access_ptr, 1);
+#endif
+    }
     // Load all inputs
     bool valid = false;
 
@@ -494,14 +531,13 @@ __global__ void __launch_bounds__(128, 1)
     if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
     {
         // Make sure all blocks have finished accessing the buffer
-        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) != gridDim.x * gridDim.y)
+        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
         {
         }
-        buffer_flags[0] = (buffer_flags[0] + 1) % 3;
-        buffer_flags[1] = (buffer_flags[1] + 1) % 3;
+        buffer_flags[0] = (flag.x + 1) % 3;
+        buffer_flags[1] = (flag.y + 1) % 3;
         *(offset_access_ptr) = 0;
     }
-    __syncthreads();
 #endif
 }
 
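
Aside: both kernels read the same packed buffer_flags word. One way to document the field meanings implied by the usage above (the struct name is hypothetical, not part of this change):

struct BufferFlags
{
    uint32_t input_idx;   // flag.x: buffer group currently consumed, rotates mod 3
    uint32_t clear_idx;   // flag.y: buffer group being reset to the -0.f sentinel, rotates mod 3
    uint32_t buffer_size; // flag.z: elements per buffer; a group holds two buffers (reduce-scatter + allgather)
    uint32_t access_ctr;  // buffer_flags[3]: count of blocks that have read the offsets this iteration
};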