Commit 1be02e2

Merge pull request #6 from unixpickle/main
vectorized loads/stores
2 parents a2407ba + 4c92aef

File tree

1 file changed

+126 −83 lines changed

accelerated_scan/warp.cuh

Lines changed: 126 additions & 83 deletions
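
What "vectorized loads/stores" means here: instead of issuing kNStepsPerThread separate scalar loads per thread, the kernel below loads one Tuple per thread, and the alignas(16) variant lets the compiler turn that struct copy into a single 128-bit memory transaction. A minimal standalone sketch of the idea, not part of this commit (the copy_tuples kernel and host driver are made up for illustration):

    // Toy demonstration of aligned tuple loads; compile with nvcc.
    #include <cuda_runtime.h>
    #include <cstdio>

    template <typename weight_t, int N>
    struct UnalignedTuple {
        static constexpr int Size = N;
        weight_t data[N];
    };

    // alignas(16) promises a 16-byte boundary, so for AlignedTuple<float, 4>
    // the struct copy below can compile to one ld.global.v4.f32 plus one
    // st.global.v4.f32 instead of four scalar transactions each way.
    template <typename weight_t, int N>
    struct alignas(16) AlignedTuple : public UnalignedTuple<weight_t, N> {};

    template <typename Tuple>
    __global__ void copy_tuples(const Tuple* src, Tuple* dst, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] = src[i]; // one vectorized load + store per thread on the aligned path
        }
    }

    int main() {
        using Tuple = AlignedTuple<float, 4>;
        constexpr int n = 1024; // tuples, i.e. 4096 floats
        Tuple *src, *dst;
        cudaMalloc(&src, n * sizeof(Tuple)); // cudaMalloc pointers are at least 256-byte aligned
        cudaMalloc(&dst, n * sizeof(Tuple));
        copy_tuples<<<(n + 255) / 256, 256>>>(src, dst, n);
        cudaDeviceSynchronize();
        printf("copied %d tuples\n", n);
        cudaFree(src);
        cudaFree(dst);
        return 0;
    }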
@@ -4,26 +4,50 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
-template <typename weight_t, int kNStepsPerThread, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
+template<typename weight_t, int N>
+class UnalignedTuple {
+public:
+    static constexpr int Size = N;
+    using Type = weight_t;
+
+    weight_t data[N];
+
+    __device__ void reverse() {
+        #pragma unroll
+        for (int i = 0; i < N/2; i++) {
+            weight_t temp = data[i];
+            data[i] = data[N - (i+1)];
+            data[N - (i+1)] = temp;
+        }
+    }
+};
+
+template<typename T, int N>
+class alignas(16) AlignedTuple : public UnalignedTuple<T, N> {
+};
+
+template <typename Tuple, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
 __global__ void scan(
-    const weight_t* gates,
-    const weight_t* tokens,
-    weight_t* result,
+    const Tuple* gates,
+    const Tuple* tokens,
+    Tuple* result,
     const int batch_stride,
     const int dim_stride,
     const bool reverse
 ) {
+    using weight_t = typename Tuple::Type;
+
     __shared__ weight_t warpLastGate[kNWarpsPerBlock];
     __shared__ weight_t warpLastToken[kNWarpsPerBlock];
     __shared__ weight_t chunkAccGate, chunkAccToken;
 
     const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
     const int warpId = threadIdx.x / kNThreadsPerWarp;
     const int laneId = threadIdx.x % kNThreadsPerWarp;
-    const int chunklen = blockDim.x * kNStepsPerThread;
+    const int chunklen = blockDim.x * Tuple::Size;
     constexpr int kBlockLast = kNWarpsPerBlock - 1;
     constexpr int kWarpLast = kNThreadsPerWarp - 1;
-    constexpr int kThreadLast = kNStepsPerThread - 1;
+    constexpr int kThreadLast = Tuple::Size - 1;
     const weight_t kEmptyGate = 1.0;
     const weight_t kEmptyToken = 0.0;
 
@@ -32,38 +56,45 @@ __global__ void scan(
     // Scan sequentially in thread registers (level 0).
     //
 
-    weight_t accGate[kNStepsPerThread];
-    weight_t accToken[kNStepsPerThread];
-
     for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
         const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
+        const int tupleOffset = (offset + (reverse ? chunklen - ((threadIdx.x + 1) * Tuple::Size) : (threadIdx.x * Tuple::Size))) / Tuple::Size;
 
         if (chunk) {
            __syncthreads();
         }
 
+        Tuple loadedGate = gates[tupleOffset];
+        Tuple loadedToken = tokens[tupleOffset];
+        if (reverse) {
+            loadedGate.reverse();
+            loadedToken.reverse();
+        }
+
+        Tuple accGate;
+        Tuple accToken;
+
         #pragma unroll
-        for (int i = 0; i < kNStepsPerThread; ++i) {
-            const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
-            weight_t gate = gates[offset + chunkOffset];
-            weight_t token = tokens[offset + chunkOffset];
+        for (int i = 0; i < Tuple::Size; ++i) {
+            weight_t gate = loadedGate.data[i];
+            weight_t token = loadedToken.data[i];
             if (i == 0) {
                 if (chunk == 0) {
-                    accGate[0] = threadIdx.x == 0 ? kEmptyGate : gate;
-                    accToken[0] = token;
+                    accGate.data[0] = threadIdx.x == 0 ? kEmptyGate : gate;
+                    accToken.data[0] = token;
                 } else {
                     if (threadIdx.x == 0) {
                         // Add the last element of the previous chunk to the first element of the current chunk.
-                        accGate[0] = chunkAccGate * gate;
-                        accToken[0] = chunkAccToken * gate + token;
+                        accGate.data[0] = chunkAccGate * gate;
+                        accToken.data[0] = chunkAccToken * gate + token;
                     } else {
-                        accGate[0] = gate;
-                        accToken[0] = token;
+                        accGate.data[0] = gate;
+                        accToken.data[0] = token;
                     }
                 }
             } else {
-                accGate[i] = accGate[i - 1] * gate;
-                accToken[i] = accToken[i - 1] * gate + token;
+                accGate.data[i] = accGate.data[i - 1] * gate;
+                accToken.data[i] = accToken.data[i - 1] * gate + token;
             }
         }
 
@@ -73,14 +104,14 @@ __global__ void scan(
 
         #pragma unroll
         for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
-            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
-            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
+            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate.data[kThreadLast], delta);
+            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken.data[kThreadLast], delta);
 
             if (laneId >= delta) {
                 #pragma unroll
-                for (int i = 0; i < kNStepsPerThread; ++i) {
-                    accToken[i] = prev_token * accGate[i] + accToken[i];
-                    accGate[i] = prev_gate * accGate[i];
+                for (int i = 0; i < Tuple::Size; ++i) {
+                    accToken.data[i] = prev_token * accGate.data[i] + accToken.data[i];
+                    accGate.data[i] = prev_gate * accGate.data[i];
                 }
             }
         }
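
Aside, not part of the diff: the __shfl_up_sync block above is a Hillis–Steele inclusive scan over the warp for the recurrence x[t] = gate[t] * x[t-1] + token[t]. The (gate, token) pairs compose associatively under (g1, t1) ∘ (g2, t2) = (g1*g2, t1*g2 + t2), which is what makes the log-step combine valid. A minimal sketch of the same combine on one value per lane, assuming a full warp of 32 active lanes:

    // Warp-level inclusive scan of a first-order linear recurrence.
    // After the round with offset delta, lane k holds the fold of lanes
    // max(0, k - 2*delta + 1) .. k; after all log2(32) = 5 rounds, lane k
    // holds the fold of lanes 0 .. k.
    __device__ void warp_scan_pair(float& gate, float& token) {
        for (int delta = 1; delta < 32; delta *= 2) {
            float prev_gate  = __shfl_up_sync(0xffffffff, gate, delta);
            float prev_token = __shfl_up_sync(0xffffffff, token, delta);
            if ((threadIdx.x % 32) >= delta) {
                token = prev_token * gate + token;
                gate  = prev_gate * gate;
            }
        }
    }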
@@ -92,8 +123,8 @@ __global__ void scan(
         //
 
         if (laneId == kWarpLast) {
-            warpLastGate[warpId] = accGate[kThreadLast];
-            warpLastToken[warpId] = accToken[kThreadLast];
+            warpLastGate[warpId] = accGate.data[kThreadLast];
+            warpLastToken[warpId] = accToken.data[kThreadLast];
         }
 
         __syncthreads();
@@ -132,22 +163,46 @@ __global__ void scan(
         //
 
         #pragma unroll
-        for (int i = 0; i < kNStepsPerThread; ++i) {
-            const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
+        for (int i = 0; i < Tuple::Size; ++i) {
             if (warpId > 0) {
-                accToken[i] = warpLastToken[warpId-1] * accGate[i] + accToken[i];
-                accGate[i] = warpLastGate[warpId-1] * accGate[i];
+                accToken.data[i] = warpLastToken[warpId-1] * accGate.data[i] + accToken.data[i];
+                accGate.data[i] = warpLastGate[warpId-1] * accGate.data[i];
             }
-            result[offset + chunkOffset] = accToken[i];
         }
+        if (reverse) {
+            accToken.reverse();
+        }
+        result[tupleOffset] = accToken;
 
         if (laneId == kWarpLast && warpId == kBlockLast) {
-            chunkAccGate = accGate[kThreadLast];
-            chunkAccToken = accToken[kThreadLast];
+            chunkAccGate = accGate.data[kThreadLast];
+            chunkAccToken = accToken.data[kThreadLast];
         }
     }
 }
 
+#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, batch_stride, dim_stride, reverse) \
+    using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
+    using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
+    if (kNStepsPerThread == 4 && \
+        ((long)gates.data_ptr()) % 16 == 0 && \
+        ((long)tokens.data_ptr()) % 16 == 0 && \
+        ((long)out.data_ptr()) % 16 == 0) { \
+        scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<const AlignedT *>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<const AlignedT *>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<AlignedT *>(out.data_ptr<torch_weight_t>()), \
+            batch_stride, dim_stride, reverse \
+        ); \
+    } else { \
+        scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<UnalignedT *>(out.data_ptr<torch_weight_t>()), \
+            batch_stride, dim_stride, reverse \
+        ); \
+    }
+
 template <typename weight_t, typename torch_weight_t>
 void
 warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
@@ -171,109 +226,97 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 1;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 64) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 1;
         constexpr int kNChunksPerSequence = 1;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 128) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 4;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 256) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 8;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 512) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 1024) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 2048) {
         constexpr int kNStepsPerThread = 2;
        constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 4096) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 8192) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 2;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 16384) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 4;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 32768) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 8;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else if (seqlen == 65536) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 16;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
-            batch_stride, dim_stride, reverse
-        );
+        DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
+                      kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+                      batch_stride, dim_stride, reverse);
     } else {
         TORCH_CHECK(false && "seqlen must be a power of 2, >= 32, <= 65536");
     }
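
The dispatch rule encoded in DISPATCH_SCAN above boils down to one pointer-alignment test: take the vectorized AlignedTuple path only when all three buffers start on a 16-byte boundary (and kNStepsPerThread == 4, since four 4-byte elements fill exactly one 16-byte word). A minimal sketch of that test, assuming raw pointers in place of the at::Tensor arguments (the helper name all_16b_aligned is made up for illustration):

    #include <cstdint>

    // True when every buffer may be reinterpreted as an array of AlignedTuple<float, 4>.
    inline bool all_16b_aligned(const void* gates, const void* tokens, const void* out) {
        auto aligned = [](const void* p) {
            return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
        };
        return aligned(gates) && aligned(tokens) && aligned(out);
    }

PyTorch's caching allocator returns buffers aligned well past 16 bytes, so the aligned path is the common case; the UnalignedTuple fallback keeps sliced or oddly offset views correct, just without the vectorized loads/stores.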
