 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
-template <typename weight_t, int kNStepsPerThread, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
+template <typename weight_t, int N>
+class UnalignedTuple {
+public:
+    static constexpr int Size = N;
+    using Type = weight_t;
+
+    weight_t data[N];
+
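+    // Reverse the element order in place; used to run right-to-left scans through the same code path.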
+    __device__ void reverse() {
+        #pragma unroll
+        for (int i = 0; i < N/2; i++) {
+            weight_t temp = data[i];
+            data[i] = data[N - (i+1)];
+            data[N - (i+1)] = temp;
+        }
+    }
+};
+
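+// Same layout as UnalignedTuple, but 16-byte aligned so the compiler can move a
+// whole tuple with a single vectorized 128-bit load/store (e.g. four floats).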
+template <typename T, int N>
+class alignas(16) AlignedTuple : public UnalignedTuple<T, N> {
+};
+
+template <typename Tuple, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
 __global__ void scan(
-    const weight_t* gates,
-    const weight_t* tokens,
-    weight_t* result,
+    const Tuple* gates,
+    const Tuple* tokens,
+    Tuple* result,
     const int batch_stride,
     const int dim_stride,
     const bool reverse
 ) {
+    using weight_t = typename Tuple::Type;
+
     __shared__ weight_t warpLastGate[kNWarpsPerBlock];
     __shared__ weight_t warpLastToken[kNWarpsPerBlock];
     __shared__ weight_t chunkAccGate, chunkAccToken;
 
     const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
     const int warpId = threadIdx.x / kNThreadsPerWarp;
     const int laneId = threadIdx.x % kNThreadsPerWarp;
-    const int chunklen = blockDim.x * kNStepsPerThread;
+    const int chunklen = blockDim.x * Tuple::Size;
     constexpr int kBlockLast = kNWarpsPerBlock - 1;
     constexpr int kWarpLast = kNThreadsPerWarp - 1;
-    constexpr int kThreadLast = kNStepsPerThread - 1;
+    constexpr int kThreadLast = Tuple::Size - 1;
     const weight_t kEmptyGate = 1.0;
     const weight_t kEmptyToken = 0.0;
 
@@ -32,38 +56,45 @@ __global__ void scan(
     // Scan sequentially in thread registers (level 0).
     //
 
-    weight_t accGate[kNStepsPerThread];
-    weight_t accToken[kNStepsPerThread];
-
     for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
         const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
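+        // This thread's position in gates/tokens/result, counted in whole Tuples rather than scalars.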
+        const int tupleOffset = (offset + (reverse ? chunklen - ((threadIdx.x + 1) * Tuple::Size) : (threadIdx.x * Tuple::Size))) / Tuple::Size;
 
         if (chunk) {
             __syncthreads();
         }
 
+        Tuple loadedGate = gates[tupleOffset];
+        Tuple loadedToken = tokens[tupleOffset];
+        if (reverse) {
+            loadedGate.reverse();
+            loadedToken.reverse();
+        }
+
+        Tuple accGate;
+        Tuple accToken;
+
         #pragma unroll
-        for (int i = 0; i < kNStepsPerThread; ++i) {
-            const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
-            weight_t gate = gates[offset + chunkOffset];
-            weight_t token = tokens[offset + chunkOffset];
+        for (int i = 0; i < Tuple::Size; ++i) {
+            weight_t gate = loadedGate.data[i];
+            weight_t token = loadedToken.data[i];
             if (i == 0) {
                 if (chunk == 0) {
-                    accGate[0] = threadIdx.x == 0 ? kEmptyGate : gate;
-                    accToken[0] = token;
+                    accGate.data[0] = threadIdx.x == 0 ? kEmptyGate : gate;
+                    accToken.data[0] = token;
                 } else {
                     if (threadIdx.x == 0) {
                         // Add the last element of the previous chunk to the first element of the current chunk.
-                        accGate[0] = chunkAccGate * gate;
-                        accToken[0] = chunkAccToken * gate + token;
+                        accGate.data[0] = chunkAccGate * gate;
+                        accToken.data[0] = chunkAccToken * gate + token;
                     } else {
-                        accGate[0] = gate;
-                        accToken[0] = token;
+                        accGate.data[0] = gate;
+                        accToken.data[0] = token;
                     }
                 }
             } else {
-                accGate[i] = accGate[i - 1] * gate;
-                accToken[i] = accToken[i - 1] * gate + token;
+                accGate.data[i] = accGate.data[i - 1] * gate;
+                accToken.data[i] = accToken.data[i - 1] * gate + token;
             }
         }
 
@@ -73,14 +104,14 @@ __global__ void scan(
 
         #pragma unroll
         for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
-            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
-            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
+            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate.data[kThreadLast], delta);
+            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken.data[kThreadLast], delta);
 
             if (laneId >= delta) {
                 #pragma unroll
-                for (int i = 0; i < kNStepsPerThread; ++i) {
-                    accToken[i] = prev_token * accGate[i] + accToken[i];
-                    accGate[i] = prev_gate * accGate[i];
+                for (int i = 0; i < Tuple::Size; ++i) {
+                    accToken.data[i] = prev_token * accGate.data[i] + accToken.data[i];
+                    accGate.data[i] = prev_gate * accGate.data[i];
                 }
             }
         }
@@ -92,8 +123,8 @@ __global__ void scan(
         //
 
         if (laneId == kWarpLast) {
-            warpLastGate[warpId] = accGate[kThreadLast];
-            warpLastToken[warpId] = accToken[kThreadLast];
+            warpLastGate[warpId] = accGate.data[kThreadLast];
+            warpLastToken[warpId] = accToken.data[kThreadLast];
         }
 
         __syncthreads();
@@ -132,22 +163,46 @@ __global__ void scan(
         //
 
         #pragma unroll
-        for (int i = 0; i < kNStepsPerThread; ++i) {
-            const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
+        for (int i = 0; i < Tuple::Size; ++i) {
             if (warpId > 0) {
-                accToken[i] = warpLastToken[warpId-1] * accGate[i] + accToken[i];
-                accGate[i] = warpLastGate[warpId-1] * accGate[i];
+                accToken.data[i] = warpLastToken[warpId-1] * accGate.data[i] + accToken.data[i];
+                accGate.data[i] = warpLastGate[warpId-1] * accGate.data[i];
             }
-            result[offset + chunkOffset] = accToken[i];
         }
+        if (reverse) {
+            accToken.reverse();
+        }
+        result[tupleOffset] = accToken;
 
         if (laneId == kWarpLast && warpId == kBlockLast) {
-            chunkAccGate = accGate[kThreadLast];
-            chunkAccToken = accToken[kThreadLast];
+            chunkAccGate = accGate.data[kThreadLast];
+            chunkAccToken = accToken.data[kThreadLast];
         }
     }
 }
 
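+// Pick the tuple type at launch time: the aligned (vectorized) path requires
+// kNStepsPerThread == 4 and all three tensors to be 16-byte aligned;
+// otherwise fall back to unaligned loads.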
+#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, batch_stride, dim_stride, reverse) \
+    using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
+    using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
+    if (kNStepsPerThread == 4 && \
+        ((long)gates.data_ptr()) % 16 == 0 && \
+        ((long)tokens.data_ptr()) % 16 == 0 && \
+        ((long)out.data_ptr()) % 16 == 0) { \
+        scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<const AlignedT*>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<const AlignedT*>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<AlignedT*>(out.data_ptr<torch_weight_t>()), \
+            batch_stride, dim_stride, reverse \
+        ); \
+    } else { \
+        scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<UnalignedT*>(out.data_ptr<torch_weight_t>()), \
+            batch_stride, dim_stride, reverse \
+        ); \
+    }
+
 template <typename weight_t, typename torch_weight_t>
 void
 warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
@@ -171,109 +226,97 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 1;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 64) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 1;
         constexpr int kNChunksPerSequence = 1;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 128) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 4;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 256) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 8;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 512) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 1024) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 2048) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 4096) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 8192) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 2;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
    } else if (seqlen == 16384) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 4;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 32768) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 8;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 65536) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 16;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else {
         TORCH_CHECK(false && "seqlen must be a power of 2, >= 32, <= 65536");
     }
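
For reference, the scan kernel computes the first-order linear recurrence result[i] = gates[i] * result[i-1] + tokens[i] along each sequence, with result[0] = tokens[0], and the same recurrence run from the far end when reverse is set. The tuple refactor in this commit only changes how elements are loaded (one Tuple per thread instead of strided scalars); the recurrence itself is unchanged. Below is a minimal host-side sketch of that recurrence, useful for sanity-checking kernel output; it is illustrative only and not part of this commit (the helper name reference_scan is ours):

#include <vector>

// Sequential reference for the recurrence computed by scan():
//   result[0] = tokens[0]
//   result[i] = gates[i] * result[i-1] + tokens[i]
// When reverse is true, the recurrence runs from the end of the sequence.
template <typename weight_t>
std::vector<weight_t> reference_scan(const std::vector<weight_t> &gates,
                                     const std::vector<weight_t> &tokens,
                                     const bool reverse = false) {
    const int n = static_cast<int>(tokens.size());
    std::vector<weight_t> result(n);
    weight_t acc = 0.0;  // kEmptyToken: the first processed element's gate has no effect
    for (int step = 0; step < n; ++step) {
        const int i = reverse ? n - 1 - step : step;
        acc = gates[i] * acc + tokens[i];
        result[i] = acc;
    }
    return result;
}

The kernel reaches the same answer in parallel because (gate, token) pairs combine under the associative operator (g2, t2) o (g1, t1) = (g1 * g2, t1 * g2 + t2): level 0 applies it within a thread's tuple, level 1 across a warp via __shfl_up_sync, and level 2 across warps through shared memory.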