
Commit a76cc5a

vectorized loads/stores
1 parent 076c2f7 commit a76cc5a

1 file changed: accelerated_scan/warp.cuh (+227 additions, -49 deletions)
@@ -4,8 +4,30 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
+template<typename T>
+struct alignas(16) AlignedQuad {
+    T x;
+    T y;
+    T z;
+    T w;
+
+    __device__ AlignedQuad(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
+
+    __device__ void reverse() {
+        // Swap x and w
+        T temp = x;
+        x = w;
+        w = temp;
+
+        // Swap y and z
+        temp = y;
+        y = z;
+        z = temp;
+    }
+};
+
 template <typename weight_t, int kNStepsPerThread, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
-__global__ void scan(
+__global__ void scanunaligned(
     const weight_t* gates,
     const weight_t* tokens,
     weight_t* result,
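
For context: alignas(16) makes a quad of four 4-byte weights exactly 128 bits wide on a 16-byte boundary, so it can be moved with a single vectorized load or store, like CUDA's built-in float4. A minimal sketch of reading and writing through such quads; copy_quads and its launch shape are illustrative, not part of this commit. It assumes src and dst are 16-byte aligned and n is a multiple of 4, mirroring the checks the dispatch macro later in this file makes before taking the vectorized path:

// Illustrative only (not from this commit): one 128-bit load and one
// 128-bit store per thread through AlignedQuad<float>.
static_assert(sizeof(AlignedQuad<float>) == 16, "four floats pack into 128 bits");
static_assert(alignof(AlignedQuad<float>) == 16, "eligible for a single vectorized access");

__global__ void copy_quads(const float* src, float* dst, int n) {
    const int quad = blockIdx.x * blockDim.x + threadIdx.x;  // one quad per thread
    if (quad * 4 < n) {
        // Both casts require 16-byte-aligned base pointers.
        AlignedQuad<float> q = reinterpret_cast<const AlignedQuad<float>*>(src)[quad];
        reinterpret_cast<AlignedQuad<float>*>(dst)[quad] = q;
    }
}
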
@@ -148,6 +170,174 @@ __global__ void scan(
     }
 }
 
+template <typename weight_t, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
+__global__ void scanaligned(
+    const weight_t* gates,
+    const weight_t* tokens,
+    weight_t* result,
+    const int batch_stride,
+    const int dim_stride,
+    const bool reverse
+) {
+    __shared__ weight_t warpLastGate[kNWarpsPerBlock];
+    __shared__ weight_t warpLastToken[kNWarpsPerBlock];
+    __shared__ weight_t chunkAccGate, chunkAccToken;
+
+    const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
+    const int warpId = threadIdx.x / kNThreadsPerWarp;
+    const int laneId = threadIdx.x % kNThreadsPerWarp;
+    const int chunklen = blockDim.x * 4;
+    constexpr int kBlockLast = kNWarpsPerBlock - 1;
+    constexpr int kWarpLast = kNThreadsPerWarp - 1;
+    constexpr int kThreadLast = 3;
+    const weight_t kEmptyGate = 1.0;
+    const weight_t kEmptyToken = 0.0;
+
+    //
+    // Read from global memory.
+    // Scan sequentially in thread registers (level 0).
+    //
+
+    weight_t accGate[4];
+    weight_t accToken[4];
+
+    for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
+        const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
+        const int quadOffset = (offset + (reverse ? chunklen - ((threadIdx.x + 1) * 4) : (threadIdx.x * 4))) / 4;
+
+        if (chunk) {
+            __syncthreads();
+        }
+
+        AlignedQuad<weight_t> loadedGate = reinterpret_cast<const struct AlignedQuad<weight_t> *>(gates)[quadOffset];
+        AlignedQuad<weight_t> loadedToken = reinterpret_cast<const struct AlignedQuad<weight_t> *>(tokens)[quadOffset];
+        if (reverse) {
+            loadedGate.reverse();
+            loadedToken.reverse();
+        }
+
+        if (chunk == 0) {
+            accGate[0] = threadIdx.x == 0 ? kEmptyGate : loadedGate.x;
+            accToken[0] = loadedToken.x;
+        } else {
+            if (threadIdx.x == 0) {
+                accGate[0] = chunkAccGate * loadedGate.x;
+                accToken[0] = chunkAccToken * loadedGate.x + loadedToken.x;
+            } else {
+                accGate[0] = loadedGate.x;
+                accToken[0] = loadedToken.x;
+            }
+        }
+        accGate[1] = accGate[0] * loadedGate.y;
+        accGate[2] = accGate[1] * loadedGate.z;
+        accGate[3] = accGate[2] * loadedGate.w;
+        accToken[1] = loadedToken.y + accToken[0] * loadedGate.y;
+        accToken[2] = loadedToken.z + accToken[1] * loadedGate.z;
+        accToken[3] = loadedToken.w + accToken[2] * loadedGate.w;
+
+        //
+        // Scan threads in a warp using shuffling (level 1).
+        //
+
+        #pragma unroll
+        for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
+            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
+            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
+
+            if (laneId >= delta) {
+                #pragma unroll
+                for (int i = 0; i < 4; ++i) {
+                    accToken[i] = prev_token * accGate[i] + accToken[i];
+                    accGate[i] = prev_gate * accGate[i];
+                }
+            }
+        }
+
+        __syncwarp();
+
+        //
+        // Store the last element of each warp in shared memory.
+        //
+
+        if (laneId == kWarpLast) {
+            warpLastGate[warpId] = accGate[kThreadLast];
+            warpLastToken[warpId] = accToken[kThreadLast];
+        }
+
+        __syncthreads();
+
+        //
+        // Leading warp scans every warp in a block (level 2).
+        //
+
+        if (warpId == 0) {
+            weight_t warpAccGate, warpAccToken;
+            warpAccGate = (laneId < kNWarpsPerBlock) ? warpLastGate[laneId] : kEmptyGate;
+            warpAccToken = (laneId < kNWarpsPerBlock) ? warpLastToken[laneId] : kEmptyToken;
+
+            #pragma unroll
+            for (int delta = 1; delta < warpSize; delta *= 2) {
+                weight_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta);
+                weight_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta);
+
+                if (laneId >= delta) {
+                    warpAccToken = prev_token * warpAccGate + warpAccToken;
+                    warpAccGate = prev_gate * warpAccGate;
+                }
+            }
+
+            if (laneId < kNWarpsPerBlock) {
+                warpLastGate[laneId] = warpAccGate;
+                warpLastToken[laneId] = warpAccToken;
+            }
+        }
+
+        __syncthreads();
+
+        //
+        // Add the last element of the previous warp to each element of the current warp (level 0).
+        // Store to global memory.
+        //
+
+        #pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            if (warpId > 0) {
+                accToken[i] = warpLastToken[warpId-1] * accGate[i] + accToken[i];
+                accGate[i] = warpLastGate[warpId-1] * accGate[i];
+            }
+        }
+
+        AlignedQuad<weight_t> outAccToken(accToken[0], accToken[1], accToken[2], accToken[3]);
+        if (reverse) {
+            outAccToken.reverse();
+        }
+        reinterpret_cast<struct AlignedQuad<weight_t> *>(result)[quadOffset] = outAccToken;
+
+        if (laneId == kWarpLast && warpId == kBlockLast) {
+            chunkAccGate = accGate[kThreadLast];
+            chunkAccToken = accToken[kThreadLast];
+        }
+    }
+}
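
Both shuffle stages use the same associative combine: a pair (gate, token) encodes the affine map h -> gate*h + token over a span of the sequence, and composing a later span after an earlier one yields (gate_e * gate_l, token_e * gate_l + token_l). A single-element sketch of the level-1 pattern above; warp_scan_pair is illustrative, not part of this commit, and assumes a full 32-lane warp:

// Illustrative warp-wide inclusive scan over (gate, token) pairs with
// __shfl_up_sync, one pair per lane (Kogge-Stone, log2(32) = 5 rounds).
template <typename weight_t>
__device__ void warp_scan_pair(weight_t& gate, weight_t& token) {
    const int laneId = threadIdx.x % 32;
    #pragma unroll
    for (int delta = 1; delta < 32; delta *= 2) {
        weight_t prev_gate  = __shfl_up_sync(0xffffffff, gate, delta);
        weight_t prev_token = __shfl_up_sync(0xffffffff, token, delta);
        if (laneId >= delta) {
            // Compose the earlier lane's map h -> prev_gate*h + prev_token
            // with this lane's map h -> gate*h + token.
            token = prev_token * gate + token;
            gate  = prev_gate * gate;
        }
    }
}

The kernel above scans only accGate[kThreadLast]/accToken[kThreadLast] through the shuffle, then folds the received prefix into all four register elements, which is what makes the four-wide level-0 step compatible with a one-value-per-lane level-1 scan.
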
+
+#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, batch_stride, dim_stride, reverse) \
+    if (kNStepsPerThread == 4 && sizeof(weight_t) <= 4 && ((long)gates.data_ptr()) % 16 == 0 && \
+        ((long)tokens.data_ptr()) % 16 == 0 && ((long)out.data_ptr()) % 16 == 0) { \
+        scanaligned<weight_t, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()), \
+            batch_stride, dim_stride, reverse \
+        ); \
+    } else { \
+        scanunaligned<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()), \
+            batch_stride, dim_stride, reverse \
+        ); \
+    }
+
 template <typename weight_t, typename torch_weight_t>
 void
 warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
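
DISPATCH_SCAN takes the vectorized path only when each thread handles exactly four steps, the weight type fits in 4 bytes (so a quad spans at most 16 bytes), and all three tensors start on a 16-byte boundary; otherwise it falls back to scanunaligned. The same alignment predicate as a standalone helper; is_quad_aligned is illustrative, not in this file:

// Illustrative helper (not in this file): the base-address check that
// DISPATCH_SCAN applies to gates, tokens, and out. std::uintptr_t is the
// portable spelling of the macro's (long) cast.
#include <cstdint>

inline bool is_quad_aligned(const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
}

A tensor can fail this check even when its storage is aligned, e.g. after slicing with a non-zero offset, which is why the scalar fallback is kept.
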
@@ -171,109 +361,97 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 1;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 64) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 1;
         constexpr int kNChunksPerSequence = 1;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 128) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 4;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 256) {
         constexpr int kNStepsPerThread = 1;
        constexpr int kNWarpsPerBlock = 8;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 512) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 1024) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 2048) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 4096) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 8192) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 2;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 16384) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 4;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 32768) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 8;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 65536) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 16;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else {
         TORCH_CHECK(false && "seqlen must be a power of 2, >= 32, <= 65536");
     }