Commit 4c92aef

cleanup aligned / unaligned code paths
1 parent a76cc5a commit 4c92aef

1 file changed (+75, -210 lines)

accelerated_scan/warp.cuh

Lines changed: 75 additions & 210 deletions
@@ -4,48 +4,50 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
-template<typename T>
-struct alignas(16) AlignedQuad {
-    T x;
-    T y;
-    T z;
-    T w;
+template<typename weight_t, int N>
+class UnalignedTuple {
+public:
+    static constexpr int Size = N;
+    using Type = weight_t;
 
-    __device__ AlignedQuad(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
+    weight_t data[N];
 
     __device__ void reverse() {
-        // Swap x and w
-        T temp = x;
-        x = w;
-        w = temp;
-
-        // Swap y and z
-        temp = y;
-        y = z;
-        z = temp;
+        #pragma unroll
+        for (int i = 0; i < N/2; i++) {
+            weight_t temp = data[i];
+            data[i] = data[N - (i+1)];
+            data[N - (i+1)] = temp;
+        }
     }
 };
 
-template <typename weight_t, int kNStepsPerThread, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
-__global__ void scanunaligned(
-    const weight_t* gates,
-    const weight_t* tokens,
-    weight_t* result,
+template<typename T, int N>
+class alignas(16) AlignedTuple : public UnalignedTuple<T, N> {
+};
+
+template <typename Tuple, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
+__global__ void scan(
+    const Tuple* gates,
+    const Tuple* tokens,
+    Tuple* result,
     const int batch_stride,
     const int dim_stride,
     const bool reverse
 ) {
+    using weight_t = typename Tuple::Type;
+
     __shared__ weight_t warpLastGate[kNWarpsPerBlock];
     __shared__ weight_t warpLastToken[kNWarpsPerBlock];
     __shared__ weight_t chunkAccGate, chunkAccToken;
 
     const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
     const int warpId = threadIdx.x / kNThreadsPerWarp;
     const int laneId = threadIdx.x % kNThreadsPerWarp;
-    const int chunklen = blockDim.x * kNStepsPerThread;
+    const int chunklen = blockDim.x * Tuple::Size;
     constexpr int kBlockLast = kNWarpsPerBlock - 1;
     constexpr int kWarpLast = kNThreadsPerWarp - 1;
-    constexpr int kThreadLast = kNStepsPerThread - 1;
+    constexpr int kThreadLast = Tuple::Size - 1;
     const weight_t kEmptyGate = 1.0;
     const weight_t kEmptyToken = 0.0;
@@ -54,201 +56,62 @@ __global__ void scanunaligned(
     // Scan sequentially in thread registers (level 0).
     //
 
-    weight_t accGate[kNStepsPerThread];
-    weight_t accToken[kNStepsPerThread];
-
     for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
         const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
+        const int tupleOffset = (offset + (reverse ? chunklen - ((threadIdx.x + 1) * Tuple::Size) : (threadIdx.x * Tuple::Size))) / Tuple::Size;
 
         if (chunk) {
             __syncthreads();
         }
 
+        Tuple loadedGate = gates[tupleOffset];
+        Tuple loadedToken = tokens[tupleOffset];
+        if (reverse) {
+            loadedGate.reverse();
+            loadedToken.reverse();
+        }
+
+        Tuple accGate;
+        Tuple accToken;
+
         #pragma unroll
-        for (int i = 0; i < kNStepsPerThread; ++i) {
-            const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
-            weight_t gate = gates[offset + chunkOffset];
-            weight_t token = tokens[offset + chunkOffset];
+        for (int i = 0; i < Tuple::Size; ++i) {
+            weight_t gate = loadedGate.data[i];
+            weight_t token = loadedToken.data[i];
             if (i == 0) {
                 if (chunk == 0) {
-                    accGate[0] = threadIdx.x == 0 ? kEmptyGate : gate;
-                    accToken[0] = token;
+                    accGate.data[0] = threadIdx.x == 0 ? kEmptyGate : gate;
+                    accToken.data[0] = token;
                 } else {
                     if (threadIdx.x == 0) {
                         // Add the last element of the previous chunk to the first element of the current chunk.
-                        accGate[0] = chunkAccGate * gate;
-                        accToken[0] = chunkAccToken * gate + token;
+                        accGate.data[0] = chunkAccGate * gate;
+                        accToken.data[0] = chunkAccToken * gate + token;
                     } else {
-                        accGate[0] = gate;
-                        accToken[0] = token;
+                        accGate.data[0] = gate;
+                        accToken.data[0] = token;
                     }
                 }
             } else {
-                accGate[i] = accGate[i - 1] * gate;
-                accToken[i] = accToken[i - 1] * gate + token;
-            }
-        }
-
-        //
-        // Scan threads in a warp using shuffling (level 1).
-        //
-
-        #pragma unroll
-        for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
-            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
-            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
-
-            if (laneId >= delta) {
-                #pragma unroll
-                for (int i = 0; i < kNStepsPerThread; ++i) {
-                    accToken[i] = prev_token * accGate[i] + accToken[i];
-                    accGate[i] = prev_gate * accGate[i];
-                }
-            }
-        }
-
-        __syncwarp();
-
-        //
-        // Store the last element of each warp in shared memory.
-        //
-
-        if (laneId == kWarpLast) {
-            warpLastGate[warpId] = accGate[kThreadLast];
-            warpLastToken[warpId] = accToken[kThreadLast];
-        }
-
-        __syncthreads();
-
-        //
-        // Leading warp scans every warp in a block (level 2).
-        //
-
-        if (warpId == 0) {
-            weight_t warpAccGate, warpAccToken;
-            warpAccGate = (laneId < kNWarpsPerBlock) ? warpLastGate[laneId] : kEmptyGate;
-            warpAccToken = (laneId < kNWarpsPerBlock) ? warpLastToken[laneId] : kEmptyToken;
-
-            #pragma unroll
-            for (int delta = 1; delta < warpSize; delta *= 2) {
-                weight_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta);
-                weight_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta);
-
-                if (laneId >= delta) {
-                    warpAccToken = prev_token * warpAccGate + warpAccToken;
-                    warpAccGate = prev_gate * warpAccGate;
-                }
-            }
-
-            if (laneId < kNWarpsPerBlock) {
-                warpLastGate[laneId] = warpAccGate;
-                warpLastToken[laneId] = warpAccToken;
-            }
-        }
-
-        __syncthreads();
-
-        //
-        // Add the last element of the previous warp to each element of the current warp (level 0).
-        // Store to global memory.
-        //
-
-        #pragma unroll
-        for (int i = 0; i < kNStepsPerThread; ++i) {
-            const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
-            if (warpId > 0) {
-                accToken[i] = warpLastToken[warpId-1] * accGate[i] + accToken[i];
-                accGate[i] = warpLastGate[warpId-1] * accGate[i];
-            }
-            result[offset + chunkOffset] = accToken[i];
-        }
-
-        if (laneId == kWarpLast && warpId == kBlockLast) {
-            chunkAccGate = accGate[kThreadLast];
-            chunkAccToken = accToken[kThreadLast];
-        }
-    }
-}
-
-template <typename weight_t, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
-__global__ void scanaligned(
-    const weight_t* gates,
-    const weight_t* tokens,
-    weight_t* result,
-    const int batch_stride,
-    const int dim_stride,
-    const bool reverse
-) {
-    __shared__ weight_t warpLastGate[kNWarpsPerBlock];
-    __shared__ weight_t warpLastToken[kNWarpsPerBlock];
-    __shared__ weight_t chunkAccGate, chunkAccToken;
-
-    const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
-    const int warpId = threadIdx.x / kNThreadsPerWarp;
-    const int laneId = threadIdx.x % kNThreadsPerWarp;
-    const int chunklen = blockDim.x * 4;
-    constexpr int kBlockLast = kNWarpsPerBlock - 1;
-    constexpr int kWarpLast = kNThreadsPerWarp - 1;
-    constexpr int kThreadLast = 3;
-    const weight_t kEmptyGate = 1.0;
-    const weight_t kEmptyToken = 0.0;
-
-    //
-    // Read from global memory.
-    // Scan sequentially in thread registers (level 0).
-    //
-
-    weight_t accGate[4];
-    weight_t accToken[4];
-
-    for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
-        const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
-        const int quadOffset = (offset + (reverse ? chunklen - ((threadIdx.x + 1) * 4) : (threadIdx.x * 4))) / 4;
-
-        if (chunk) {
-            __syncthreads();
-        }
-
-        AlignedQuad<weight_t> loadedGate = reinterpret_cast<const struct AlignedQuad<weight_t> *>(gates)[quadOffset];
-        AlignedQuad<weight_t> loadedToken = reinterpret_cast<const struct AlignedQuad<weight_t> *>(tokens)[quadOffset];
-        if (reverse) {
-            loadedGate.reverse();
-            loadedToken.reverse();
-        }
-
-        if (chunk == 0) {
-            accGate[0] = threadIdx.x == 0 ? kEmptyGate : loadedGate.x;
-            accToken[0] = loadedToken.x;
-        } else {
-            if (threadIdx.x == 0) {
-                accGate[0] = chunkAccGate * loadedGate.x;
-                accToken[0] = chunkAccToken * loadedGate.x + loadedToken.x;
-            } else {
-                accGate[0] = loadedGate.x;
-                accToken[0] = loadedToken.x;
+                accGate.data[i] = accGate.data[i - 1] * gate;
+                accToken.data[i] = accToken.data[i - 1] * gate + token;
             }
         }
-        accGate[1] = accGate[0] * loadedGate.y;
-        accGate[2] = accGate[1] * loadedGate.z;
-        accGate[3] = accGate[2] * loadedGate.w;
-        accToken[1] = loadedToken.y + accToken[0] * loadedGate.y;
-        accToken[2] = loadedToken.z + accToken[1] * loadedGate.z;
-        accToken[3] = loadedToken.w + accToken[2] * loadedGate.w;
 
         //
         // Scan threads in a warp using shuffling (level 1).
         //
 
         #pragma unroll
         for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
-            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
-            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
+            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate.data[kThreadLast], delta);
+            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken.data[kThreadLast], delta);
 
             if (laneId >= delta) {
                 #pragma unroll
-                for (int i = 0; i < 4; ++i) {
-                    accToken[i] = prev_token * accGate[i] + accToken[i];
-                    accGate[i] = prev_gate * accGate[i];
+                for (int i = 0; i < Tuple::Size; ++i) {
+                    accToken.data[i] = prev_token * accGate.data[i] + accToken.data[i];
+                    accGate.data[i] = prev_gate * accGate.data[i];
                 }
             }
         }
@@ -260,8 +123,8 @@ __global__ void scanaligned(
         //
 
         if (laneId == kWarpLast) {
-            warpLastGate[warpId] = accGate[kThreadLast];
-            warpLastToken[warpId] = accToken[kThreadLast];
+            warpLastGate[warpId] = accGate.data[kThreadLast];
+            warpLastToken[warpId] = accToken.data[kThreadLast];
         }
 
         __syncthreads();
@@ -300,40 +163,42 @@ __global__ void scanaligned(
        //
 
         #pragma unroll
-        for (int i = 0; i < 4; ++i) {
+        for (int i = 0; i < Tuple::Size; ++i) {
             if (warpId > 0) {
-                accToken[i] = warpLastToken[warpId-1] * accGate[i] + accToken[i];
-                accGate[i] = warpLastGate[warpId-1] * accGate[i];
+                accToken.data[i] = warpLastToken[warpId-1] * accGate.data[i] + accToken.data[i];
+                accGate.data[i] = warpLastGate[warpId-1] * accGate.data[i];
             }
         }
-
-        AlignedQuad<weight_t> outAccToken(accToken[0], accToken[1], accToken[2], accToken[3]);
         if (reverse) {
-            outAccToken.reverse();
+            accToken.reverse();
         }
-        reinterpret_cast<struct AlignedQuad<weight_t> *>(result)[quadOffset] = outAccToken;
+        result[tupleOffset] = accToken;
 
         if (laneId == kWarpLast && warpId == kBlockLast) {
-            chunkAccGate = accGate[kThreadLast];
-            chunkAccToken = accToken[kThreadLast];
+            chunkAccGate = accGate.data[kThreadLast];
+            chunkAccToken = accToken.data[kThreadLast];
         }
     }
 }
 
 #define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, batch_stride, dim_stride, reverse) \
-    if (kNStepsPerThread == 4 && sizeof(weight_t) <= 4 && ((long)gates.data_ptr()) % 16 == 0 && \
-        ((long)tokens.data_ptr()) % 16 == 0 && ((long)out.data_ptr()) % 16 == 0) { \
-        scanaligned<weight_t, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), \
-            reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), \
-            reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()), \
+    using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
+    using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
+    if (kNStepsPerThread == 4 && \
+        ((long)gates.data_ptr()) % 16 == 0 && \
+        ((long)tokens.data_ptr()) % 16 == 0 && \
+        ((long)out.data_ptr()) % 16 == 0) { \
+        scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<const AlignedT *>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<const AlignedT *>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<AlignedT *>(out.data_ptr<torch_weight_t>()), \
             batch_stride, dim_stride, reverse \
         ); \
     } else { \
-        scanunaligned<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
-            reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), \
-            reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), \
-            reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()), \
+        scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+            reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
+            reinterpret_cast<UnalignedT *>(out.data_ptr<torch_weight_t>()), \
            batch_stride, dim_stride, reverse \
         ); \
     }
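
Below is a minimal, self-contained sketch (not part of this commit) of the tuple abstraction the diff introduces. The class bodies mirror the diff; the demo kernel copyReversed and main() are hypothetical and only illustrate the point of the alignas(16) variant: when the base pointer is 16-byte aligned, a tuple of four floats can be moved with a single 128-bit access instead of four scalar ones, which is the case the DISPATCH_SCAN alignment check selects.

// Standalone sketch (hypothetical demo code, not from the repository).
#include <cstdio>
#include <cuda_runtime.h>

template<typename weight_t, int N>
class UnalignedTuple {
public:
    static constexpr int Size = N;
    using Type = weight_t;
    weight_t data[N];

    __device__ void reverse() {
        #pragma unroll
        for (int i = 0; i < N / 2; i++) {
            weight_t temp = data[i];
            data[i] = data[N - (i + 1)];
            data[N - (i + 1)] = temp;
        }
    }
};

// Same layout, but the 16-byte alignment lets the compiler use a single
// 128-bit load/store for a tuple of four floats.
template<typename T, int N>
class alignas(16) AlignedTuple : public UnalignedTuple<T, N> {
};

// Toy kernel: each thread copies one tuple, reversing it, as the scan kernel
// does for the reverse=true direction.
template<typename Tuple>
__global__ void copyReversed(const Tuple* in, Tuple* out, int nTuples) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < nTuples) {
        Tuple t = in[i];
        t.reverse();
        out[i] = t;
    }
}

int main() {
    using Quad = AlignedTuple<float, 4>;
    constexpr int nTuples = 8;

    Quad *in, *out;
    cudaMallocManaged(&in, nTuples * sizeof(Quad));
    cudaMallocManaged(&out, nTuples * sizeof(Quad));
    for (int i = 0; i < nTuples; i++)
        for (int j = 0; j < Quad::Size; j++)
            in[i].data[j] = float(i * Quad::Size + j);

    // Pointers from cudaMallocManaged are amply aligned, so the aligned type is
    // safe here; the DISPATCH_SCAN macro checks `ptr % 16 == 0` before choosing it.
    copyReversed<<<1, nTuples>>>(in, out, nTuples);
    cudaDeviceSynchronize();

    printf("out[0] = %g %g %g %g\n",
           out[0].data[0], out[0].data[1], out[0].data[2], out[0].data[3]);
    cudaFree(in);
    cudaFree(out);
    return 0;
}

Because the unaligned fallback has identical element-wise semantics, the same scan body compiles against either type, which is what lets this commit collapse scanaligned and scanunaligned into the single scan kernel.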

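For reference, here is a sequential host-side version of the recurrence the kernel parallelizes (hedged; not from the repository, the name scan_reference is hypothetical). It is exactly the combine step visible in the register-level scan above, accToken.data[i] = accToken.data[i-1] * gate + token, with the initial state matching kEmptyToken = 0.

// Hypothetical sequential reference for h[t] = gates[t] * h[t-1] + tokens[t], h[-1] = 0.
#include <vector>

template <typename weight_t>
std::vector<weight_t> scan_reference(const std::vector<weight_t>& gates,
                                     const std::vector<weight_t>& tokens,
                                     bool reverse = false) {
    const size_t n = tokens.size();
    std::vector<weight_t> result(n);
    weight_t h = weight_t(0);  // corresponds to kEmptyToken in the kernel
    for (size_t k = 0; k < n; ++k) {
        // reverse=true scans the sequence back to front, as the kernel's reverse flag does.
        const size_t t = reverse ? n - 1 - k : k;
        h = gates[t] * h + tokens[t];
        result[t] = h;
    }
    return result;
}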