 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
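+// Four consecutive elements, aligned to the size of the whole quad (16 bytes
+// for float). The alignment lets a thread move all four values with a single
+// vectorized load/store instead of four scalar ones; reverse() mirrors the
+// quad in place so the same load/store path can serve reverse scans. The
+// alignment is tied to the element size rather than a fixed 16 bytes so that
+// sizeof(AlignedQuad<T>) is exactly 4 * sizeof(T); a fixed alignas(16) would
+// pad two-byte element types (which the dispatch below admits) and throw off
+// the quad indexing.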
+template <typename T>
+struct alignas(4 * sizeof(T)) AlignedQuad {
+    T x;
+    T y;
+    T z;
+    T w;
+
+    __device__ AlignedQuad(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
+
+    __device__ void reverse() {
+        // Swap x and w
+        T temp = x;
+        x = w;
+        w = temp;
+
+        // Swap y and z
+        temp = y;
+        y = z;
+        z = temp;
+    }
+};
+
 template <typename weight_t, int kNStepsPerThread, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
-__global__ void scan(
+__global__ void scanunaligned(
     const weight_t *gates,
     const weight_t *tokens,
     weight_t *result,
@@ -148,6 +170,174 @@ __global__ void scan(
     }
 }
 
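+// Same three-level scan as scanunaligned, specialized to exactly four elements
+// per thread so that each thread's slice of gates/tokens/result moves as one
+// AlignedQuad. kNStepsPerThread is therefore fixed at 4 and drops out of the
+// template parameter list.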
+template <typename weight_t, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
+__global__ void scanaligned(
+    const weight_t *gates,
+    const weight_t *tokens,
+    weight_t *result,
+    const int batch_stride,
+    const int dim_stride,
+    const bool reverse
+) {
+    __shared__ weight_t warpLastGate[kNWarpsPerBlock];
+    __shared__ weight_t warpLastToken[kNWarpsPerBlock];
+    __shared__ weight_t chunkAccGate, chunkAccToken;
+
+    const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
+    const int warpId = threadIdx.x / kNThreadsPerWarp;
+    const int laneId = threadIdx.x % kNThreadsPerWarp;
+    const int chunklen = blockDim.x * 4;
+    constexpr int kBlockLast = kNWarpsPerBlock - 1;
+    constexpr int kWarpLast = kNThreadsPerWarp - 1;
+    constexpr int kThreadLast = 3;
+    const weight_t kEmptyGate = 1.0;
+    const weight_t kEmptyToken = 0.0;
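+
+    // The scan combines (gate, token) pairs with the associative operator
+    //   (g1, t1) * (g2, t2) = (g1 * g2, t1 * g2 + t2),
+    // which realizes the recurrence h[t] = gate[t] * h[t-1] + token[t].
+    // (kEmptyGate, kEmptyToken) = (1, 0) is its identity and pads idle lanes.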
+
+    //
+    // Read from global memory.
+    // Scan sequentially in thread registers (level 0).
+    //
+
+    weight_t accGate[4];
+    weight_t accToken[4];
+
+    for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
+        const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
+        const int quadOffset = (offset + (reverse ? chunklen - ((threadIdx.x + 1) * 4) : (threadIdx.x * 4))) / 4;
+
+        if (chunk) {
+            __syncthreads();
+        }
+
+        AlignedQuad<weight_t> loadedGate = reinterpret_cast<const AlignedQuad<weight_t> *>(gates)[quadOffset];
+        AlignedQuad<weight_t> loadedToken = reinterpret_cast<const AlignedQuad<weight_t> *>(tokens)[quadOffset];
+        if (reverse) {
+            loadedGate.reverse();
+            loadedToken.reverse();
+        }
+
+        if (chunk == 0) {
+            accGate[0] = threadIdx.x == 0 ? kEmptyGate : loadedGate.x;
+            accToken[0] = loadedToken.x;
+        } else {
+            if (threadIdx.x == 0) {
+                accGate[0] = chunkAccGate * loadedGate.x;
+                accToken[0] = chunkAccToken * loadedGate.x + loadedToken.x;
+            } else {
+                accGate[0] = loadedGate.x;
+                accToken[0] = loadedToken.x;
+            }
+        }
+        accGate[1] = accGate[0] * loadedGate.y;
+        accGate[2] = accGate[1] * loadedGate.z;
+        accGate[3] = accGate[2] * loadedGate.w;
+        accToken[1] = loadedToken.y + accToken[0] * loadedGate.y;
+        accToken[2] = loadedToken.z + accToken[1] * loadedGate.z;
+        accToken[3] = loadedToken.w + accToken[2] * loadedGate.w;
+
+        //
+        // Scan threads in a warp using shuffling (level 1).
+        //
+
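+        // Hillis-Steele inclusive scan over lanes: each lane publishes the
+        // running total of its quad (index kThreadLast) and, once it has
+        // received the combined prefix of the `delta` lanes below it, folds
+        // that prefix into all four of its register elements.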
+        #pragma unroll
+        for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
+            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
+            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
+
+            if (laneId >= delta) {
+                #pragma unroll
+                for (int i = 0; i < 4; ++i) {
+                    accToken[i] = prev_token * accGate[i] + accToken[i];
+                    accGate[i] = prev_gate * accGate[i];
+                }
+            }
+        }
+
+        __syncwarp();
+
+        //
+        // Store the last element of each warp in shared memory.
+        //
+
+        if (laneId == kWarpLast) {
+            warpLastGate[warpId] = accGate[kThreadLast];
+            warpLastToken[warpId] = accToken[kThreadLast];
+        }
+
+        __syncthreads();
+
+        //
+        // Leading warp scans every warp in a block (level 2).
+        //
+
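+        // This reuses the same shuffle scan on the per-warp totals, so it
+        // relies on kNWarpsPerBlock <= warpSize; lanes with no warp to
+        // represent carry the identity element and are masked out below.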
+        if (warpId == 0) {
+            weight_t warpAccGate, warpAccToken;
+            warpAccGate = (laneId < kNWarpsPerBlock) ? warpLastGate[laneId] : kEmptyGate;
+            warpAccToken = (laneId < kNWarpsPerBlock) ? warpLastToken[laneId] : kEmptyToken;
+
+            #pragma unroll
+            for (int delta = 1; delta < warpSize; delta *= 2) {
+                weight_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta);
+                weight_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta);
+
+                if (laneId >= delta) {
+                    warpAccToken = prev_token * warpAccGate + warpAccToken;
+                    warpAccGate = prev_gate * warpAccGate;
+                }
+            }
+
+            if (laneId < kNWarpsPerBlock) {
+                warpLastGate[laneId] = warpAccGate;
+                warpLastToken[laneId] = warpAccToken;
+            }
+        }
+
+        __syncthreads();
+
+        //
+        // Add the last element of the previous warp to each element of the current warp (level 0).
+        // Store to global memory.
+        //
+
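+        // warpLastGate/warpLastToken[warpId - 1] now hold the inclusive scan
+        // of everything before this warp; apply it as an exclusive prefix to
+        // all four register elements.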
+        #pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            if (warpId > 0) {
+                accToken[i] = warpLastToken[warpId - 1] * accGate[i] + accToken[i];
+                accGate[i] = warpLastGate[warpId - 1] * accGate[i];
+            }
+        }
+
+        AlignedQuad<weight_t> outAccToken(accToken[0], accToken[1], accToken[2], accToken[3]);
+        if (reverse) {
+            outAccToken.reverse();
+        }
+        reinterpret_cast<AlignedQuad<weight_t> *>(result)[quadOffset] = outAccToken;
+
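+        // The last thread of the block records the chunk totals; the
+        // __syncthreads() at the top of the next iteration makes them visible
+        // before thread 0 folds them into the next chunk.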
+        if (laneId == kWarpLast && warpId == kBlockLast) {
+            chunkAccGate = accGate[kThreadLast];
+            chunkAccToken = accToken[kThreadLast];
+        }
+    }
+}
+
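+// Pick the kernel at launch time: the vectorized scanaligned path requires
+// exactly four steps per thread, an element type of at most four bytes, and
+// 16-byte-aligned gate/token/output pointers; anything else falls back to the
+// general scanunaligned kernel.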
+#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, batch_stride, dim_stride, reverse) \
+    if (kNStepsPerThread == 4 && sizeof(weight_t) <= 4 && ((long) gates.data_ptr()) % 16 == 0 && \
+        ((long) tokens.data_ptr()) % 16 == 0 && ((long) out.data_ptr()) % 16 == 0) { \
+        scanaligned<weight_t, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()), \
+            batch_stride, dim_stride, reverse \
+        ); \
+    } else { \
+        scanunaligned<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()), \
+            batch_stride, dim_stride, reverse \
+        ); \
+    }
+
 template <typename weight_t, typename torch_weight_t>
 void
 warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
@@ -171,109 +361,97 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 1;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 64) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 1;
         constexpr int kNChunksPerSequence = 1;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 128) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 4;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 256) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 8;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 512) {
         constexpr int kNStepsPerThread = 1;
        constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 1024) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 2048) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 4096) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 8192) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 2;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 16384) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 4;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 32768) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 8;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 65536) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 16;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else {
         TORCH_CHECK(false && "seqlen must be a power of 2, >= 32, <= 65536");
     }