44#include < ATen/cuda/CUDAContext.h>
55#include < torch/extension.h>
66
7+ #define CHECK_STRIDE (x ) TORCH_CHECK(x.stride(-1 ) == 1 || x.size(-1 ) == 1 );
8+
79template <typename weight_t , int N>
810class UnalignedTuple {
911public:
@@ -26,11 +28,33 @@ template<typename T, int N>
2628class alignas (16 ) AlignedTuple : public UnalignedTuple<T, N> {
2729};
2830
29- template <typename Tuple, int kNThreadsPerWarp , int kNWarpsPerBlock , int kNChunksPerSequence >
31+ template <typename Tuple, int offset>
32+ __device__ Tuple load_shifted_tuple (const Tuple* ptr, int index, int minIdx, int maxIdx) {
33+ using weight_t = typename Tuple::Type;
34+
35+ const weight_t * rawPtr = reinterpret_cast <const weight_t *>(ptr);
36+ Tuple x;
37+ for (int i = 0 ; i < Tuple::Size; i++) {
38+ const int idx = index * Tuple::Size + i + offset;
39+ if (idx >= minIdx * Tuple::Size && idx < maxIdx * Tuple::Size) {
40+ x.data [i] = rawPtr[idx];
41+ } else {
42+ x.data [i] = 0.0 ;
43+ }
44+ }
45+
46+ return x;
47+ }
48+
49+ template <typename Tuple, int kNThreadsPerWarp , int kNWarpsPerBlock , int kNChunksPerSequence , bool backward>
3050__global__ void scan (
3151 const Tuple* gates,
3252 const Tuple* tokens,
3353 Tuple* result,
54+ // Only passed if backward is True.
55+ const Tuple* output,
56+ Tuple* gateGradOut,
57+ // Shape information
3458 const int batch_stride,
3559 const int dim_stride,
3660 const bool reverse
@@ -51,6 +75,10 @@ __global__ void scan(
5175 const weight_t kEmptyGate = 1.0 ;
5276 const weight_t kEmptyToken = 0.0 ;
5377
78+ // Limits for loading shifted tuples during backward pass.
79+ const int minIdx = seqoffset / Tuple::Size;
80+ const int maxIdx = minIdx + blockDim .x * kNChunksPerSequence ;
81+
5482 //
5583 // Read from global memory.
5684 // Scan sequentially in thread registers (level 0).
@@ -64,7 +92,12 @@ __global__ void scan(
6492 __syncthreads ();
6593 }
6694
67- Tuple loadedGate = gates[tupleOffset];
95+ Tuple loadedGate;
96+ if (backward) {
97+ loadedGate = load_shifted_tuple<Tuple, 1 >(gates, tupleOffset, minIdx, maxIdx);
98+ } else {
99+ loadedGate = gates[tupleOffset];
100+ }
68101 Tuple loadedToken = tokens[tupleOffset];
69102 if (reverse) {
70103 loadedGate.reverse ();
@@ -174,43 +207,68 @@ __global__ void scan(
174207 }
175208 result[tupleOffset] = accToken;
176209
210+ if (backward) {
211+ Tuple gateGrad = load_shifted_tuple<Tuple, -1 >(output, tupleOffset, minIdx, maxIdx);
212+ for (int i = 0 ; i < Tuple::Size; i++) {
213+ gateGrad.data [i] = gateGrad.data [i] * accToken.data [i];
214+ }
215+ gateGradOut[tupleOffset] = gateGrad;
216+ }
217+
177218 if (laneId == kWarpLast && warpId == kBlockLast ) {
178219 chunkAccGate = accGate.data [kThreadLast ];
179220 chunkAccToken = accToken.data [kThreadLast ];
180221 }
181222 }
182223}
183224
184- #define DISPATCH_SCAN (weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, batch_stride, dim_stride, reverse ) \
225+ #define DISPATCH_SCAN_INNER (TupleT, backward, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse ) \
226+ scan<TupleT, kNThreadsPerWarp , kNWarpsPerBlock , kNChunksPerSequence , backward><<<grid, kNThreads , kNWarpsPerBlock * sizeof (weight_t ) * 2 , stream>>> ( \
227+ reinterpret_cast <const TupleT *>(gates.data_ptr<torch_weight_t >()), \
228+ reinterpret_cast <const TupleT *>(tokens.data_ptr<torch_weight_t >()), \
229+ reinterpret_cast <TupleT *>(out.data_ptr<torch_weight_t >()), \
230+ reinterpret_cast <const TupleT *>(output), \
231+ reinterpret_cast <TupleT *>(gateGradOut), \
232+ batch_stride, dim_stride, reverse \
233+ );
234+
235+ #define DISPATCH_SCAN (weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse ) \
185236 using AlignedT = AlignedTuple<weight_t , kNStepsPerThread >; \
186237 using UnalignedT = UnalignedTuple<weight_t , kNStepsPerThread >; \
187238 if (kNStepsPerThread == 4 && \
188239 ((long )gates.data_ptr()) % 16 == 0 && \
189240 ((long )tokens.data_ptr()) % 16 == 0 && \
190- ((long )out.data_ptr()) % 16 == 0 ) { \
191- scan<AlignedT, kNThreadsPerWarp , kNWarpsPerBlock , kNChunksPerSequence ><<<grid, kNThreads , kNWarpsPerBlock * sizeof (weight_t ) * 2 , stream>>> ( \
192- reinterpret_cast <const AlignedT *>(gates.data_ptr <torch_weight_t >()), \
193- reinterpret_cast <const AlignedT *>(tokens.data_ptr <torch_weight_t >()), \
194- reinterpret_cast <AlignedT *>(out.data_ptr <torch_weight_t >()), \
195- batch_stride, dim_stride, reverse \
196- ); \
241+ ((long )out.data_ptr()) % 16 == 0 && \
242+ ((long )output) % 16 == 0 && \
243+ ((long )gateGradOut) % 16 == 0 ) { \
244+ if (output) { \
245+ DISPATCH_SCAN_INNER (AlignedT, true , weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock , kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
246+ } else { \
247+ DISPATCH_SCAN_INNER (AlignedT, false , weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock , kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
248+ } \
197249 } else { \
198- scan<UnalignedT, kNThreadsPerWarp , kNWarpsPerBlock , kNChunksPerSequence ><<<grid, kNThreads , kNWarpsPerBlock * sizeof (weight_t ) * 2 , stream>>> ( \
199- reinterpret_cast <const UnalignedT*>(gates.data_ptr <torch_weight_t >()), \
200- reinterpret_cast <const UnalignedT*>(tokens.data_ptr <torch_weight_t >()), \
201- reinterpret_cast <UnalignedT *>(out.data_ptr <torch_weight_t >()), \
202- batch_stride, dim_stride, reverse \
203- ); \
250+ if (output) { \
251+ DISPATCH_SCAN_INNER (UnalignedT, true , weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock , kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
252+ } else { \
253+ DISPATCH_SCAN_INNER (UnalignedT, false , weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock , kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
254+ } \
204255 }
205256
206257template <typename weight_t , typename torch_weight_t >
207258void
208- warpscan (const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
259+ warpscan (
260+ const at::Tensor &gates,
261+ const at::Tensor &tokens,
262+ const at::Tensor &out,
263+ const void *output,
264+ void *gateGradOut,
265+ const bool reverse
266+ ) {
209267 const auto strides = tokens.strides ();
210268 const int batch_stride = strides[0 ];
211269 const int dim_stride = strides[1 ];
212- TORCH_CHECK (tokens. stride (- 1 ) == 1 || tokens. size (- 1 ) == 1 );
213- TORCH_CHECK (gates. stride (- 1 ) == 1 || gates. size (- 1 ) == 1 );
270+ CHECK_STRIDE (tokens);
271+ CHECK_STRIDE (gates);
214272
215273 const auto sizes = tokens.sizes ();
216274 const int batch_size = sizes[0 ];
@@ -227,119 +285,140 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
227285 int kNThreads = seqlen / kNStepsPerThread ;
228286 constexpr int kNChunksPerSequence = 1 ;
229287 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
230- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
288+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
231289 batch_stride, dim_stride, reverse);
232290 } else if (seqlen == 64 ) {
233291 constexpr int kNStepsPerThread = 2 ;
234292 constexpr int kNWarpsPerBlock = 1 ;
235293 constexpr int kNChunksPerSequence = 1 ;
236294 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
237295 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
238- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
296+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
239297 batch_stride, dim_stride, reverse);
240298 } else if (seqlen == 128 ) {
241299 constexpr int kNStepsPerThread = 1 ;
242300 constexpr int kNWarpsPerBlock = 4 ;
243301 int kNThreads = seqlen / kNStepsPerThread ;
244302 constexpr int kNChunksPerSequence = 1 ;
245303 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
246- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
304+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
247305 batch_stride, dim_stride, reverse);
248306 } else if (seqlen == 256 ) {
249307 constexpr int kNStepsPerThread = 1 ;
250308 constexpr int kNWarpsPerBlock = 8 ;
251309 int kNThreads = seqlen / kNStepsPerThread ;
252310 constexpr int kNChunksPerSequence = 1 ;
253311 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
254- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
312+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
255313 batch_stride, dim_stride, reverse);
256314 } else if (seqlen == 512 ) {
257315 constexpr int kNStepsPerThread = 1 ;
258316 constexpr int kNWarpsPerBlock = 16 ;
259317 int kNThreads = seqlen / kNStepsPerThread ;
260318 constexpr int kNChunksPerSequence = 1 ;
261319 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
262- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
320+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
263321 batch_stride, dim_stride, reverse);
264322 } else if (seqlen == 1024 ) {
265323 constexpr int kNStepsPerThread = 2 ;
266324 constexpr int kNWarpsPerBlock = 16 ;
267325 int kNThreads = seqlen / kNStepsPerThread ;
268326 constexpr int kNChunksPerSequence = 1 ;
269327 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
270- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
328+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
271329 batch_stride, dim_stride, reverse);
272330 } else if (seqlen == 2048 ) {
273331 constexpr int kNStepsPerThread = 2 ;
274332 constexpr int kNWarpsPerBlock = 32 ;
275333 int kNThreads = seqlen / kNStepsPerThread ;
276334 constexpr int kNChunksPerSequence = 1 ;
277335 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
278- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
336+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
279337 batch_stride, dim_stride, reverse);
280338 } else if (seqlen == 4096 ) {
281339 constexpr int kNStepsPerThread = 4 ;
282340 constexpr int kNWarpsPerBlock = 32 ;
283341 int kNThreads = seqlen / kNStepsPerThread ;
284342 constexpr int kNChunksPerSequence = 1 ;
285343 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
286- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
344+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
287345 batch_stride, dim_stride, reverse);
288346 } else if (seqlen == 8192 ) {
289347 constexpr int kNStepsPerThread = 4 ;
290348 constexpr int kNWarpsPerBlock = 32 ;
291349 constexpr int kNChunksPerSequence = 2 ;
292350 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
293351 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
294- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
352+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
295353 batch_stride, dim_stride, reverse);
296354 } else if (seqlen == 16384 ) {
297355 constexpr int kNStepsPerThread = 4 ;
298356 constexpr int kNWarpsPerBlock = 32 ;
299357 constexpr int kNChunksPerSequence = 4 ;
300358 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
301359 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
302- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
360+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
303361 batch_stride, dim_stride, reverse);
304362 } else if (seqlen == 32768 ) {
305363 constexpr int kNStepsPerThread = 4 ;
306364 constexpr int kNWarpsPerBlock = 32 ;
307365 constexpr int kNChunksPerSequence = 8 ;
308366 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
309367 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
310- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
368+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
311369 batch_stride, dim_stride, reverse);
312370 } else if (seqlen == 65536 ) {
313371 constexpr int kNStepsPerThread = 4 ;
314372 constexpr int kNWarpsPerBlock = 32 ;
315373 constexpr int kNChunksPerSequence = 16 ;
316374 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
317375 DISPATCH_SCAN (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
318- kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out,
376+ kNChunksPerSequence , grid, kNThreads , stream, gates, tokens, out, output, gateGradOut,
319377 batch_stride, dim_stride, reverse);
320378 } else {
321379 TORCH_CHECK (false && " seqlen must be a power of 2, >= 32, <= 65536" );
322380 }
323381}
324382
383+ #define DISPATCH_WARPSCAN (gates, ...) \
384+ if (gates.scalar_type() == at::ScalarType::BFloat16) { \
385+ warpscan<__nv_bfloat16, at::BFloat16>(gates, __VA_ARGS__); \
386+ } else if (gates.scalar_type() == at::ScalarType::Half) { \
387+ warpscan<__half, at::Half>(gates, __VA_ARGS__); \
388+ } else if (gates.scalar_type() == at::ScalarType::Float) { \
389+ warpscan<float , float >(gates, __VA_ARGS__); \
390+ } else { \
391+ TORCH_CHECK (false && " Unsupported tensor dtype: expecting bfloat16, float16 or float32" ); \
392+ }
393+
325394at::Tensor
326395warpscan_forward (const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
327396 TORCH_CHECK (tokens.is_cuda ());
328397 TORCH_CHECK (gates.is_cuda ());
329398 TORCH_CHECK (tokens.is_contiguous ());
330399 TORCH_CHECK (gates.is_contiguous ());
400+ TORCH_CHECK (tokens.scalar_type () == gates.scalar_type ());
401+ TORCH_CHECK (tokens.scalar_type () == out.scalar_type ());
331402
332- if (tokens.scalar_type () == at::ScalarType::BFloat16) {
333- TORCH_CHECK (gates.scalar_type () == at::ScalarType::BFloat16);
334- warpscan<__nv_bfloat16, at::BFloat16>(gates, tokens, out, reverse);
335- } else if (tokens.scalar_type () == at::ScalarType::Half) {
336- TORCH_CHECK (gates.scalar_type () == at::ScalarType::Half);
337- warpscan<__half, at::Half>(gates, tokens, out, reverse);
338- } else if (tokens.scalar_type () == at::ScalarType::Float) {
339- TORCH_CHECK (gates.scalar_type () == at::ScalarType::Float);
340- warpscan<float , float >(gates, tokens, out, reverse);
341- } else {
342- TORCH_CHECK (false && " Unsupported tensor dtype: expecting bfloat16, float16 or float32" );
343- }
403+ DISPATCH_WARPSCAN (gates, tokens, out, nullptr , nullptr , reverse);
344404 return out;
345- }
405+ }
406+
407+ void
408+ warpscan_backward (const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& tokenGradOut) {
409+ TORCH_CHECK (gates.is_cuda ());
410+ TORCH_CHECK (output.is_cuda ());
411+ TORCH_CHECK (outGrad.is_cuda ());
412+ TORCH_CHECK (gateGradOut.is_contiguous ());
413+ TORCH_CHECK (tokenGradOut.is_contiguous ());
414+ TORCH_CHECK (gates.scalar_type () == output.scalar_type ());
415+ TORCH_CHECK (gates.scalar_type () == outGrad.scalar_type ());
416+ TORCH_CHECK (gates.scalar_type () == gateGradOut.scalar_type ());
417+ TORCH_CHECK (gates.scalar_type () == tokenGradOut.scalar_type ());
418+ TORCH_CHECK (gates.sizes () == output.sizes ());
419+ TORCH_CHECK (gates.sizes () == outGrad.sizes ());
420+ TORCH_CHECK (gates.sizes () == gateGradOut.sizes ());
421+ TORCH_CHECK (gates.sizes () == tokenGradOut.sizes ());
422+
423+ DISPATCH_WARPSCAN (gates, outGrad, tokenGradOut, output.data_ptr (), gateGradOut.data_ptr (), true );
424+ }
0 commit comments