#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

- template <typename T>
- struct alignas(16) AlignedQuad {
-     T x;
-     T y;
-     T z;
-     T w;
+ template <typename weight_t, int N>
+ class UnalignedTuple {
+ public:
+     static constexpr int Size = N;
+     using Type = weight_t;

-     __device__ AlignedQuad(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
+     weight_t data[N];

    __device__ void reverse() {
-         // Swap x and w
-         T temp = x;
-         x = w;
-         w = temp;
-
-         // Swap y and z
-         temp = y;
-         y = z;
-         z = temp;
+         #pragma unroll
+         for (int i = 0; i < N / 2; i++) {
+             weight_t temp = data[i];
+             data[i] = data[N - (i + 1)];
+             data[N - (i + 1)] = temp;
+         }
    }
};

- template <typename weight_t, int kNStepsPerThread, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
- __global__ void scanunaligned(
-     const weight_t* gates,
-     const weight_t* tokens,
-     weight_t* result,
+ template <typename T, int N>
+ class alignas(16) AlignedTuple : public UnalignedTuple<T, N> {
+ };
+
+ template <typename Tuple, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
+ __global__ void scan(
+     const Tuple* gates,
+     const Tuple* tokens,
+     Tuple* result,
    const int batch_stride,
    const int dim_stride,
    const bool reverse
) {
+     using weight_t = typename Tuple::Type;
+
    __shared__ weight_t warpLastGate[kNWarpsPerBlock];
    __shared__ weight_t warpLastToken[kNWarpsPerBlock];
    __shared__ weight_t chunkAccGate, chunkAccToken;

    const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
    const int warpId = threadIdx.x / kNThreadsPerWarp;
    const int laneId = threadIdx.x % kNThreadsPerWarp;
-     const int chunklen = blockDim.x * kNStepsPerThread;
+     const int chunklen = blockDim.x * Tuple::Size;
    constexpr int kBlockLast = kNWarpsPerBlock - 1;
    constexpr int kWarpLast = kNThreadsPerWarp - 1;
-     constexpr int kThreadLast = kNStepsPerThread - 1;
+     constexpr int kThreadLast = Tuple::Size - 1;
    const weight_t kEmptyGate = 1.0;
    const weight_t kEmptyToken = 0.0;

@@ -54,201 +56,62 @@ __global__ void scanunaligned(
    //
    // Scan sequentially in thread registers (level 0).
    //

-     weight_t accGate[kNStepsPerThread];
-     weight_t accToken[kNStepsPerThread];
-
    for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
        const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
+         const int tupleOffset = (offset + (reverse ? chunklen - ((threadIdx.x + 1) * Tuple::Size) : (threadIdx.x * Tuple::Size))) / Tuple::Size;

        if (chunk) {
            __syncthreads();
        }

+         Tuple loadedGate = gates[tupleOffset];
+         Tuple loadedToken = tokens[tupleOffset];
+         if (reverse) {
+             loadedGate.reverse();
+             loadedToken.reverse();
+         }
+
+         Tuple accGate;
+         Tuple accToken;
+
        #pragma unroll
-         for (int i = 0; i < kNStepsPerThread; ++i) {
-             const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
-             weight_t gate = gates[offset + chunkOffset];
-             weight_t token = tokens[offset + chunkOffset];
+         for (int i = 0; i < Tuple::Size; ++i) {
+             weight_t gate = loadedGate.data[i];
+             weight_t token = loadedToken.data[i];
            if (i == 0) {
                if (chunk == 0) {
-                     accGate[0] = threadIdx.x == 0 ? kEmptyGate : gate;
-                     accToken[0] = token;
+                     accGate.data[0] = threadIdx.x == 0 ? kEmptyGate : gate;
+                     accToken.data[0] = token;
                } else {
                    if (threadIdx.x == 0) {
                        // Add the last element of the previous chunk to the first element of the current chunk.
-                         accGate[0] = chunkAccGate * gate;
-                         accToken[0] = chunkAccToken * gate + token;
+                         accGate.data[0] = chunkAccGate * gate;
+                         accToken.data[0] = chunkAccToken * gate + token;
                    } else {
-                         accGate[0] = gate;
-                         accToken[0] = token;
+                         accGate.data[0] = gate;
+                         accToken.data[0] = token;
                    }
                }
            } else {
-                 accGate[i] = accGate[i - 1] * gate;
-                 accToken[i] = accToken[i - 1] * gate + token;
-             }
-         }
-
-         //
-         // Scan threads in a warp using shuffling (level 1).
-         //
-
-         #pragma unroll
-         for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
-             weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
-             weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
-
-             if (laneId >= delta) {
-                 #pragma unroll
-                 for (int i = 0; i < kNStepsPerThread; ++i) {
-                     accToken[i] = prev_token * accGate[i] + accToken[i];
-                     accGate[i] = prev_gate * accGate[i];
-                 }
-             }
-         }
-
-         __syncwarp();
-
-         //
-         // Store the last element of each warp in shared memory.
-         //
-
-         if (laneId == kWarpLast) {
-             warpLastGate[warpId] = accGate[kThreadLast];
-             warpLastToken[warpId] = accToken[kThreadLast];
-         }
-
-         __syncthreads();
-
-         //
-         // Leading warp scans every warp in a block (level 2).
-         //
-
-         if (warpId == 0) {
-             weight_t warpAccGate, warpAccToken;
-             warpAccGate = (laneId < kNWarpsPerBlock) ? warpLastGate[laneId] : kEmptyGate;
-             warpAccToken = (laneId < kNWarpsPerBlock) ? warpLastToken[laneId] : kEmptyToken;
-
-             #pragma unroll
-             for (int delta = 1; delta < warpSize; delta *= 2) {
-                 weight_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta);
-                 weight_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta);
-
-                 if (laneId >= delta) {
-                     warpAccToken = prev_token * warpAccGate + warpAccToken;
-                     warpAccGate = prev_gate * warpAccGate;
-                 }
-             }
-
-             if (laneId < kNWarpsPerBlock) {
-                 warpLastGate[laneId] = warpAccGate;
-                 warpLastToken[laneId] = warpAccToken;
-             }
-         }
-
-         __syncthreads();
-
-         //
-         // Add the last element of the previous warp to each element of the current warp (level 0).
-         // Store to global memory.
-         //
-
-         #pragma unroll
-         for (int i = 0; i < kNStepsPerThread; ++i) {
-             const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
-             if (warpId > 0) {
-                 accToken[i] = warpLastToken[warpId-1] * accGate[i] + accToken[i];
-                 accGate[i] = warpLastGate[warpId-1] * accGate[i];
-             }
-             result[offset + chunkOffset] = accToken[i];
-         }
-
-         if (laneId == kWarpLast && warpId == kBlockLast) {
-             chunkAccGate = accGate[kThreadLast];
-             chunkAccToken = accToken[kThreadLast];
-         }
-     }
- }
-
- template <typename weight_t, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
- __global__ void scanaligned(
-     const weight_t* gates,
-     const weight_t* tokens,
-     weight_t* result,
-     const int batch_stride,
-     const int dim_stride,
-     const bool reverse
- ) {
-     __shared__ weight_t warpLastGate[kNWarpsPerBlock];
-     __shared__ weight_t warpLastToken[kNWarpsPerBlock];
-     __shared__ weight_t chunkAccGate, chunkAccToken;
-
-     const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
-     const int warpId = threadIdx.x / kNThreadsPerWarp;
-     const int laneId = threadIdx.x % kNThreadsPerWarp;
-     const int chunklen = blockDim.x * 4;
-     constexpr int kBlockLast = kNWarpsPerBlock - 1;
-     constexpr int kWarpLast = kNThreadsPerWarp - 1;
-     constexpr int kThreadLast = 3;
-     const weight_t kEmptyGate = 1.0;
-     const weight_t kEmptyToken = 0.0;
-
-     //
-     // Read from global memory.
-     // Scan sequentially in thread registers (level 0).
-     //
-
-     weight_t accGate[4];
-     weight_t accToken[4];
-
-     for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
-         const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
-         const int quadOffset = (offset + (reverse ? chunklen - ((threadIdx.x + 1) * 4) : (threadIdx.x * 4))) / 4;
-
-         if (chunk) {
-             __syncthreads();
-         }
-
-         AlignedQuad<weight_t> loadedGate = reinterpret_cast<const struct AlignedQuad<weight_t> *>(gates)[quadOffset];
-         AlignedQuad<weight_t> loadedToken = reinterpret_cast<const struct AlignedQuad<weight_t> *>(tokens)[quadOffset];
-         if (reverse) {
-             loadedGate.reverse();
-             loadedToken.reverse();
-         }
-
-         if (chunk == 0) {
-             accGate[0] = threadIdx.x == 0 ? kEmptyGate : loadedGate.x;
-             accToken[0] = loadedToken.x;
-         } else {
-             if (threadIdx.x == 0) {
-                 accGate[0] = chunkAccGate * loadedGate.x;
-                 accToken[0] = chunkAccToken * loadedGate.x + loadedToken.x;
-             } else {
-                 accGate[0] = loadedGate.x;
-                 accToken[0] = loadedToken.x;
+                 accGate.data[i] = accGate.data[i - 1] * gate;
+                 accToken.data[i] = accToken.data[i - 1] * gate + token;
            }
        }
-         accGate[1] = accGate[0] * loadedGate.y;
-         accGate[2] = accGate[1] * loadedGate.z;
-         accGate[3] = accGate[2] * loadedGate.w;
-         accToken[1] = loadedToken.y + accToken[0] * loadedGate.y;
-         accToken[2] = loadedToken.z + accToken[1] * loadedGate.z;
-         accToken[3] = loadedToken.w + accToken[2] * loadedGate.w;

        //
        // Scan threads in a warp using shuffling (level 1).
        //

        #pragma unroll
        for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
-             weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
-             weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
+             weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate.data[kThreadLast], delta);
+             weight_t prev_token = __shfl_up_sync(0xffffffff, accToken.data[kThreadLast], delta);

            if (laneId >= delta) {
                #pragma unroll
-                 for (int i = 0; i < 4; ++i) {
-                     accToken[i] = prev_token * accGate[i] + accToken[i];
-                     accGate[i] = prev_gate * accGate[i];
+                 for (int i = 0; i < Tuple::Size; ++i) {
+                     accToken.data[i] = prev_token * accGate.data[i] + accToken.data[i];
+                     accGate.data[i] = prev_gate * accGate.data[i];
                }
            }
        }
@@ -260,8 +123,8 @@ __global__ void scanaligned(
        //
        //
        if (laneId == kWarpLast) {
-             warpLastGate[warpId] = accGate[kThreadLast];
-             warpLastToken[warpId] = accToken[kThreadLast];
+             warpLastGate[warpId] = accGate.data[kThreadLast];
+             warpLastToken[warpId] = accToken.data[kThreadLast];
        }

        __syncthreads();
@@ -300,40 +163,42 @@ __global__ void scanaligned(
        //
        //
        #pragma unroll
-         for (int i = 0; i < 4; ++i) {
+         for (int i = 0; i < Tuple::Size; ++i) {
            if (warpId > 0) {
-                 accToken[i] = warpLastToken[warpId-1] * accGate[i] + accToken[i];
-                 accGate[i] = warpLastGate[warpId-1] * accGate[i];
+                 accToken.data[i] = warpLastToken[warpId-1] * accGate.data[i] + accToken.data[i];
+                 accGate.data[i] = warpLastGate[warpId-1] * accGate.data[i];
            }
        }
-
-         AlignedQuad<weight_t> outAccToken(accToken[0], accToken[1], accToken[2], accToken[3]);
        if (reverse) {
-             outAccToken.reverse();
+             accToken.reverse();
        }
-         reinterpret_cast<struct AlignedQuad<weight_t> *>(result)[quadOffset] = outAccToken;
+         result[tupleOffset] = accToken;

        if (laneId == kWarpLast && warpId == kBlockLast) {
-             chunkAccGate = accGate[kThreadLast];
-             chunkAccToken = accToken[kThreadLast];
+             chunkAccGate = accGate.data[kThreadLast];
+             chunkAccToken = accToken.data[kThreadLast];
        }
    }
}

#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, batch_stride, dim_stride, reverse) \
-     if (kNStepsPerThread == 4 && sizeof(weight_t) <= 4 && ((long)gates.data_ptr()) % 16 == 0 && \
-         ((long)tokens.data_ptr()) % 16 == 0 && ((long)out.data_ptr()) % 16 == 0) { \
-         scanaligned<weight_t, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
-             reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), \
-             reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), \
-             reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()), \
+     using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
+     using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
+     if (kNStepsPerThread == 4 && \
+         ((long)gates.data_ptr()) % 16 == 0 && \
+         ((long)tokens.data_ptr()) % 16 == 0 && \
+         ((long)out.data_ptr()) % 16 == 0) { \
+         scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+             reinterpret_cast<const AlignedT*>(gates.data_ptr<torch_weight_t>()), \
+             reinterpret_cast<const AlignedT*>(tokens.data_ptr<torch_weight_t>()), \
+             reinterpret_cast<AlignedT*>(out.data_ptr<torch_weight_t>()), \
            batch_stride, dim_stride, reverse \
        ); \
    } else { \
-         scanunaligned<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
-             reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), \
-             reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), \
-             reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()), \
+         scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+             reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
+             reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
+             reinterpret_cast<UnalignedT*>(out.data_ptr<torch_weight_t>()), \
            batch_stride, dim_stride, reverse \
        ); \
    }
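
For context on what the tuple refactor enables, here is a minimal standalone sketch, not part of this commit: the Tuple and scan_level0 names, the float type, and N = 4 are illustrative assumptions. A 16-byte-aligned tuple lets each thread pull four contiguous values with one vectorized load and then run the level-0 recurrence acc[i] = acc[i-1] * gate[i] + token[i] entirely in registers, which is the per-chunk pattern the unified scan kernel applies before its warp- and block-level shuffle stages.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for AlignedTuple: N scalars loaded as one 16-byte transaction.
template <typename T, int N>
struct alignas(16) Tuple {
    T data[N];
};

// Level-0 sketch: one thread handles one tuple of 4 elements.
// The first element's gate is treated as identity, mirroring the
// kernel's thread 0 / chunk 0 case.
__global__ void scan_level0(const Tuple<float, 4>* gates,
                            const Tuple<float, 4>* tokens,
                            Tuple<float, 4>* out) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    Tuple<float, 4> g = gates[idx];   // single vectorized load
    Tuple<float, 4> t = tokens[idx];
    Tuple<float, 4> acc;
    acc.data[0] = t.data[0];
    #pragma unroll
    for (int i = 1; i < 4; i++) {
        acc.data[i] = acc.data[i - 1] * g.data[i] + t.data[i];
    }
    out[idx] = acc;                   // single vectorized store
}

int main() {
    Tuple<float, 4> hg = {{0.5f, 0.5f, 0.5f, 0.5f}};
    Tuple<float, 4> ht = {{1.0f, 1.0f, 1.0f, 1.0f}};
    Tuple<float, 4> *dg, *dt, *dr;
    cudaMalloc(&dg, sizeof(hg));
    cudaMalloc(&dt, sizeof(ht));
    cudaMalloc(&dr, sizeof(ht));
    cudaMemcpy(dg, &hg, sizeof(hg), cudaMemcpyHostToDevice);
    cudaMemcpy(dt, &ht, sizeof(ht), cudaMemcpyHostToDevice);
    scan_level0<<<1, 1>>>(dg, dt, dr);
    Tuple<float, 4> hr;
    cudaMemcpy(&hr, dr, sizeof(hr), cudaMemcpyDeviceToHost);
    // Expected output: 1.0 1.5 1.75 1.875
    printf("%f %f %f %f\n", hr.data[0], hr.data[1], hr.data[2], hr.data[3]);
    cudaFree(dg); cudaFree(dt); cudaFree(dr);
    return 0;
}

The real kernel keeps this structure but loads per chunk, seeds the first element from the previous chunk's accumulator, and composes partial results across threads and warps with __shfl_up_sync and shared memory, as shown in the diff above.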