Commit 0e12f07

Merge pull request #7 from unixpickle/fused-backward
Fused backward pass kernel
2 parents 5cb9403 + f47c2f1 commit 0e12f07

2 files changed: +129 −51 lines

accelerated_scan/warp.cuh

Lines changed: 123 additions & 44 deletions
@@ -4,6 +4,8 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
+#define CHECK_STRIDE(x) TORCH_CHECK(x.stride(-1) == 1 || x.size(-1) == 1);
+
 template<typename weight_t, int N>
 class UnalignedTuple {
 public:
@@ -26,11 +28,33 @@ template<typename T, int N>
 class alignas(16) AlignedTuple : public UnalignedTuple<T, N> {
 };
 
-template <typename Tuple, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
+template <typename Tuple, int offset>
+__device__ Tuple load_shifted_tuple(const Tuple* ptr, int index, int minIdx, int maxIdx) {
+    using weight_t = typename Tuple::Type;
+
+    const weight_t* rawPtr = reinterpret_cast<const weight_t *>(ptr);
+    Tuple x;
+    for (int i = 0; i < Tuple::Size; i++) {
+        const int idx = index * Tuple::Size + i + offset;
+        if (idx >= minIdx * Tuple::Size && idx < maxIdx * Tuple::Size) {
+            x.data[i] = rawPtr[idx];
+        } else {
+            x.data[i] = 0.0;
+        }
+    }
+
+    return x;
+}
+
+template <typename Tuple, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence, bool backward>
 __global__ void scan(
     const Tuple* gates,
     const Tuple* tokens,
     Tuple* result,
+    // Only passed if backward is True.
+    const Tuple* output,
+    Tuple* gateGradOut,
+    // Shape information
     const int batch_stride,
     const int dim_stride,
     const bool reverse
@@ -51,6 +75,10 @@ __global__ void scan(
     const weight_t kEmptyGate = 1.0;
     const weight_t kEmptyToken = 0.0;
 
+    // Limits for loading shifted tuples during backward pass.
+    const int minIdx = seqoffset / Tuple::Size;
+    const int maxIdx = minIdx + blockDim.x * kNChunksPerSequence;
+
     //
     // Read from global memory.
     // Scan sequentially in thread registers (level 0).
@@ -64,7 +92,12 @@ __global__ void scan(
             __syncthreads();
         }
 
-        Tuple loadedGate = gates[tupleOffset];
+        Tuple loadedGate;
+        if (backward) {
+            loadedGate = load_shifted_tuple<Tuple, 1>(gates, tupleOffset, minIdx, maxIdx);
+        } else {
+            loadedGate = gates[tupleOffset];
+        }
         Tuple loadedToken = tokens[tupleOffset];
         if (reverse) {
             loadedGate.reverse();
@@ -174,43 +207,68 @@ __global__ void scan(
         }
         result[tupleOffset] = accToken;
 
+        if (backward) {
+            Tuple gateGrad = load_shifted_tuple<Tuple, -1>(output, tupleOffset, minIdx, maxIdx);
+            for (int i = 0; i < Tuple::Size; i++) {
+                gateGrad.data[i] = gateGrad.data[i] * accToken.data[i];
+            }
+            gateGradOut[tupleOffset] = gateGrad;
+        }
+
         if (laneId == kWarpLast && warpId == kBlockLast) {
             chunkAccGate = accGate.data[kThreadLast];
             chunkAccToken = accToken.data[kThreadLast];
         }
     }
 }
 
-#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, batch_stride, dim_stride, reverse) \
+#define DISPATCH_SCAN_INNER(TupleT, backward, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
+    scan<TupleT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, backward><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
+        reinterpret_cast<const TupleT *>(gates.data_ptr<torch_weight_t>()), \
+        reinterpret_cast<const TupleT *>(tokens.data_ptr<torch_weight_t>()), \
+        reinterpret_cast<TupleT *>(out.data_ptr<torch_weight_t>()), \
+        reinterpret_cast<const TupleT *>(output), \
+        reinterpret_cast<TupleT *>(gateGradOut), \
+        batch_stride, dim_stride, reverse \
+    );
+
+#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
     using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
     using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
     if (kNStepsPerThread == 4 && \
         ((long)gates.data_ptr()) % 16 == 0 && \
         ((long)tokens.data_ptr()) % 16 == 0 && \
-        ((long)out.data_ptr()) % 16 == 0) { \
-        scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
-            reinterpret_cast<const AlignedT *>(gates.data_ptr<torch_weight_t>()), \
-            reinterpret_cast<const AlignedT *>(tokens.data_ptr<torch_weight_t>()), \
-            reinterpret_cast<AlignedT *>(out.data_ptr<torch_weight_t>()), \
-            batch_stride, dim_stride, reverse \
-        ); \
+        ((long)out.data_ptr()) % 16 == 0 && \
+        ((long)output) % 16 == 0 && \
+        ((long)gateGradOut) % 16 == 0) { \
+        if (output) { \
+            DISPATCH_SCAN_INNER(AlignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
+        } else { \
+            DISPATCH_SCAN_INNER(AlignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
+        } \
     } else { \
-        scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
-            reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
-            reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
-            reinterpret_cast<UnalignedT *>(out.data_ptr<torch_weight_t>()), \
-            batch_stride, dim_stride, reverse \
-        ); \
+        if (output) { \
+            DISPATCH_SCAN_INNER(UnalignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
+        } else { \
+            DISPATCH_SCAN_INNER(UnalignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
+        } \
     }
 
 template <typename weight_t, typename torch_weight_t>
 void
-warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
+warpscan(
+    const at::Tensor &gates,
+    const at::Tensor &tokens,
+    const at::Tensor &out,
+    const void *output,
+    void *gateGradOut,
+    const bool reverse
+) {
     const auto strides = tokens.strides();
     const int batch_stride = strides[0];
     const int dim_stride = strides[1];
-    TORCH_CHECK(tokens.stride(-1) == 1 || tokens.size(-1) == 1);
-    TORCH_CHECK(gates.stride(-1) == 1 || gates.size(-1) == 1);
+    CHECK_STRIDE(tokens);
+    CHECK_STRIDE(gates);
 
     const auto sizes = tokens.sizes();
     const int batch_size = sizes[0];
@@ -227,119 +285,140 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 64) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 1;
         constexpr int kNChunksPerSequence = 1;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 128) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 4;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 256) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 8;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 512) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 1024) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 2048) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 4096) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 8192) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 2;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 16384) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 4;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 32768) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 8;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else if (seqlen == 65536) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 16;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
         DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
-            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
+            kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
             batch_stride, dim_stride, reverse);
     } else {
         TORCH_CHECK(false && "seqlen must be a power of 2, >= 32, <= 65536");
     }
 }
 
+#define DISPATCH_WARPSCAN(gates, ...) \
+    if (gates.scalar_type() == at::ScalarType::BFloat16) { \
+        warpscan<__nv_bfloat16, at::BFloat16>(gates, __VA_ARGS__); \
+    } else if (gates.scalar_type() == at::ScalarType::Half) { \
+        warpscan<__half, at::Half>(gates, __VA_ARGS__); \
+    } else if (gates.scalar_type() == at::ScalarType::Float) { \
+        warpscan<float, float>(gates, __VA_ARGS__); \
+    } else { \
+        TORCH_CHECK(false && "Unsupported tensor dtype: expecting bfloat16, float16 or float32"); \
+    }
+
 at::Tensor
 warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
     TORCH_CHECK(tokens.is_cuda());
     TORCH_CHECK(gates.is_cuda());
     TORCH_CHECK(tokens.is_contiguous());
     TORCH_CHECK(gates.is_contiguous());
+    TORCH_CHECK(tokens.scalar_type() == gates.scalar_type());
+    TORCH_CHECK(tokens.scalar_type() == out.scalar_type());
 
-    if (tokens.scalar_type() == at::ScalarType::BFloat16) {
-        TORCH_CHECK(gates.scalar_type() == at::ScalarType::BFloat16);
-        warpscan<__nv_bfloat16, at::BFloat16>(gates, tokens, out, reverse);
-    } else if (tokens.scalar_type() == at::ScalarType::Half) {
-        TORCH_CHECK(gates.scalar_type() == at::ScalarType::Half);
-        warpscan<__half, at::Half>(gates, tokens, out, reverse);
-    } else if (tokens.scalar_type() == at::ScalarType::Float) {
-        TORCH_CHECK(gates.scalar_type() == at::ScalarType::Float);
-        warpscan<float, float>(gates, tokens, out, reverse);
-    } else {
-        TORCH_CHECK(false && "Unsupported tensor dtype: expecting bfloat16, float16 or float32");
-    }
+    DISPATCH_WARPSCAN(gates, tokens, out, nullptr, nullptr, reverse);
     return out;
-}
+}
+
+void
+warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& tokenGradOut) {
+    TORCH_CHECK(gates.is_cuda());
+    TORCH_CHECK(output.is_cuda());
+    TORCH_CHECK(outGrad.is_cuda());
+    TORCH_CHECK(gateGradOut.is_contiguous());
+    TORCH_CHECK(tokenGradOut.is_contiguous());
+    TORCH_CHECK(gates.scalar_type() == output.scalar_type());
+    TORCH_CHECK(gates.scalar_type() == outGrad.scalar_type());
+    TORCH_CHECK(gates.scalar_type() == gateGradOut.scalar_type());
+    TORCH_CHECK(gates.scalar_type() == tokenGradOut.scalar_type());
+    TORCH_CHECK(gates.sizes() == output.sizes());
+    TORCH_CHECK(gates.sizes() == outGrad.sizes());
+    TORCH_CHECK(gates.sizes() == gateGradOut.sizes());
+    TORCH_CHECK(gates.sizes() == tokenGradOut.sizes());
+
+    DISPATCH_WARPSCAN(gates, outGrad, tokenGradOut, output.data_ptr(), gateGradOut.data_ptr(), true);
+}
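
In the backward branch above, the kernel reuses the reverse scan: scanning grad_output with the gate of the following position (loaded via load_shifted_tuple<Tuple, 1>) yields the token gradients in accToken, and multiplying each of those by the previous output (loaded via load_shifted_tuple<Tuple, -1>) gives the gate gradients written to gateGradOut. Below is a minimal loop-based PyTorch sketch of the same computation, assuming the forward recurrence h[t] = gates[t] * h[t-1] + tokens[t] over the last dimension; the function name and loops are illustrative, not part of the repository.

import torch

def reference_backward(gates, states, grad_output):
    # Token gradients: reverse scan of grad_output with the gates shifted left by one,
    # i.e. d_tokens[t] = grad_output[t] + gates[t+1] * d_tokens[t+1].
    seqlen = grad_output.shape[-1]
    d_tokens = torch.zeros_like(grad_output)
    carry = torch.zeros_like(grad_output[..., 0])
    for t in reversed(range(seqlen)):
        gate_next = gates[..., t + 1] if t + 1 < seqlen else torch.ones_like(carry)
        carry = grad_output[..., t] + gate_next * carry
        d_tokens[..., t] = carry
    # Gate gradients: previous output times the token gradient,
    # i.e. d_gates[t] = states[t-1] * d_tokens[t], with zero at t = 0.
    prev_states = torch.cat([torch.zeros_like(states[..., :1]), states[..., :-1]], dim=-1)
    d_gates = prev_states * d_tokens
    return d_gates, d_tokens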

accelerated_scan/warp.py

Lines changed: 6 additions & 7 deletions
@@ -7,13 +7,14 @@
 
 cpp_source = """
 at::Tensor warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse);
+void warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut);
 """
 
 module = load_inline(
     name='warpscan',
     cpp_sources=[cpp_source],
     cuda_sources=[cuda_source],
-    functions=['warpscan_forward'],
+    functions=['warpscan_forward', 'warpscan_backward'],
     verbose=True,
     extra_cuda_cflags=[
         "-O3",
@@ -26,6 +27,7 @@
     ]
 )
 warpscan_forward = module.warpscan_forward
+warpscan_backward = module.warpscan_backward
 
 def scan_forward(gates, tokens, reverse=False):
     output = torch.zeros_like(tokens)
@@ -56,13 +58,10 @@ def backward(ctx, grad_output):
         assert states.is_contiguous()
         assert gates.is_contiguous()
 
-        padded_shifted_gates = torch.cat([gates, torch.ones_like(gates[:, :, :1])], dim=-1)[:, :, 1:].contiguous()
-        d_states = scan_forward(padded_shifted_gates, grad_output, reverse=True)
+        d_gates = torch.empty_like(gates)
+        d_tokens = torch.empty_like(gates)
+        warpscan_backward(gates, states, grad_output, d_gates, d_tokens)
 
-        padded_outputs = torch.cat([torch.zeros_like(states[:, :, :1]), states], dim=-1)[:, :, :-1]
-        d_gates = padded_outputs * d_states
-
-        d_tokens = d_states
         return d_gates, d_tokens
 

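
With the fused kernel, the backward method above preallocates both gradient tensors and fills them in a single warpscan_backward call, instead of materializing shifted copies of gates and states in Python and running a second scan. A hedged usage sketch of the new binding follows; the import path is assumed from the file location, and shapes follow the (batch, dim, seqlen) layout with seqlen a power of two between 32 and 65536.

import torch
from accelerated_scan.warp import scan_forward, warpscan_backward  # assumed import path

batch, dim, seqlen = 2, 8, 64
gates = torch.rand(batch, dim, seqlen, device="cuda").contiguous()
tokens = torch.randn(batch, dim, seqlen, device="cuda").contiguous()

states = scan_forward(gates, tokens)    # forward outputs h[t]
grad_output = torch.randn_like(states)  # upstream gradient

# Gradient outputs are preallocated by the caller and must be contiguous,
# with the same dtype and shape as gates, matching the TORCH_CHECKs above.
d_gates = torch.empty_like(gates)
d_tokens = torch.empty_like(tokens)
warpscan_backward(gates, states, grad_output, d_gates, d_tokens)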

0 commit comments

Comments
 (0)