@@ -347,14 +347,14 @@ warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Te
347347}
348348
349349template <typename Tuple, int offset>
350- __device__ Tuple load_shifted_tuple (const Tuple* ptr, int index, int limit ) {
350+ __device__ Tuple load_shifted_tuple (const Tuple* ptr, int index, int minIdx, int maxIdx ) {
351351 using weight_t = typename Tuple::Type;
352352
353353 const weight_t * rawPtr = reinterpret_cast <const weight_t *>(ptr);
354354 Tuple x;
355355 for (int i = 0 ; i < Tuple::Size; i++) {
356356 const int idx = index * Tuple::Size + i + offset;
357- if (idx >= 0 && idx < limit * Tuple::Size) {
357+ if (idx >= minIdx * Tuple::Size && idx < maxIdx * Tuple::Size) {
358358 x.data [i] = rawPtr[idx];
359359 } else {
360360 x.data [i] = 0.0 ;
@@ -370,7 +370,7 @@ __global__ void scan_grad(
370370 const Tuple* output,
371371 const Tuple* outGrad,
372372 Tuple* gateGradOut,
373- Tuple* valueGradOut ,
373+ Tuple* tokenGradOut ,
374374 const int batch_stride,
375375 const int dim_stride
376376) {
@@ -389,7 +389,10 @@ __global__ void scan_grad(
389389 constexpr int kThreadLast = Tuple::Size - 1 ;
390390 const weight_t kEmptyGate = 1.0 ;
391391 const weight_t kEmptyToken = 0.0 ;
392- const int limit = blockDim .x * kNChunksPerSequence ;
392+
393+ // Limits for loading shifted tuples.
394+ const int minIdx = blockDim .x * kNChunksPerSequence * blockIdx .x ;
395+ const int maxIdx = blockDim .x * kNChunksPerSequence * (blockIdx .x + 1 );
393396
394397 for (int chunk = 0 ; chunk < kNChunksPerSequence ; chunk++) {
395398 const int offset = seqoffset + (kNChunksPerSequence - 1 - chunk) * chunklen;
@@ -400,7 +403,7 @@ __global__ void scan_grad(
400403 }
401404
402405 // Load from global memory.
403- Tuple loadedGate = load_shifted_tuple<Tuple, 1 >(gates, tupleOffset, limit );
406+ Tuple loadedGate = load_shifted_tuple<Tuple, 1 >(gates, tupleOffset, minIdx, maxIdx );
404407 Tuple loadedToken = outGrad[tupleOffset];
405408 loadedGate.reverse ();
406409 loadedToken.reverse ();
@@ -505,9 +508,9 @@ __global__ void scan_grad(
505508 }
506509 }
507510 accToken.reverse ();
508- valueGradOut [tupleOffset] = accToken;
511+ tokenGradOut [tupleOffset] = accToken;
509512
510- Tuple gateGrad = load_shifted_tuple<Tuple, -1 >(output, tupleOffset, limit );
513+ Tuple gateGrad = load_shifted_tuple<Tuple, -1 >(output, tupleOffset, minIdx, maxIdx );
511514 for (int i = 0 ; i < Tuple::Size; i++) {
512515 gateGrad.data [i] = gateGrad.data [i] * accToken.data [i];
513516 }
@@ -520,21 +523,21 @@ __global__ void scan_grad(
520523 }
521524}
522525
523- #define DISPATCH_SCAN_GRAD (weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, output, outGrad, gateGradOut, valueGradOut , batch_stride, dim_stride ) \
526+ #define DISPATCH_SCAN_GRAD (weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, output, outGrad, gateGradOut, tokenGradOut , batch_stride, dim_stride ) \
524527 using AlignedT = AlignedTuple<weight_t , kNStepsPerThread >; \
525528 using UnalignedT = UnalignedTuple<weight_t , kNStepsPerThread >; \
526529 if (kNStepsPerThread == 4 && \
527530 ((long )gates.data_ptr()) % 16 == 0 && \
528531 ((long )output.data_ptr()) % 16 == 0 && \
529532 ((long )outGrad.data_ptr()) % 16 == 0 && \
530533 ((long )gateGradOut.data_ptr()) % 16 == 0 && \
531- ((long )valueGradOut .data_ptr()) % 16 == 0 ) { \
534+ ((long )tokenGradOut .data_ptr()) % 16 == 0 ) { \
532535 scan_grad<AlignedT, kNThreadsPerWarp , kNWarpsPerBlock , kNChunksPerSequence ><<<grid, kNThreads , kNWarpsPerBlock * sizeof (weight_t ) * 2 , stream>>> ( \
533536 reinterpret_cast <const AlignedT *>(gates.data_ptr <torch_weight_t >()), \
534537 reinterpret_cast <const AlignedT *>(output.data_ptr <torch_weight_t >()), \
535538 reinterpret_cast <const AlignedT *>(outGrad.data_ptr <torch_weight_t >()), \
536539 reinterpret_cast <AlignedT *>(gateGradOut.data_ptr <torch_weight_t >()), \
537- reinterpret_cast <AlignedT *>(valueGradOut .data_ptr <torch_weight_t >()), \
540+ reinterpret_cast <AlignedT *>(tokenGradOut .data_ptr <torch_weight_t >()), \
538541 batch_stride, dim_stride \
539542 ); \
540543 } else { \
@@ -543,23 +546,23 @@ __global__ void scan_grad(
543546 reinterpret_cast <const UnalignedT *>(output.data_ptr <torch_weight_t >()), \
544547 reinterpret_cast <const UnalignedT *>(outGrad.data_ptr <torch_weight_t >()), \
545548 reinterpret_cast <UnalignedT *>(gateGradOut.data_ptr <torch_weight_t >()), \
546- reinterpret_cast <UnalignedT *>(valueGradOut .data_ptr <torch_weight_t >()), \
549+ reinterpret_cast <UnalignedT *>(tokenGradOut .data_ptr <torch_weight_t >()), \
547550 batch_stride, dim_stride \
548551 ); \
549552 }
550553
551554template <typename weight_t , typename torch_weight_t >
552555void
553556warpscan_grad (const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad,
554- const at::Tensor& gateGradOut, const at::Tensor& valueGradOut ) {
557+ const at::Tensor& gateGradOut, const at::Tensor& tokenGradOut ) {
555558 const auto strides = gates.strides ();
556559 const int batch_stride = strides[0 ];
557560 const int dim_stride = strides[1 ];
558561 CHECK_STRIDE (gates);
559562 CHECK_STRIDE (output);
560563 CHECK_STRIDE (outGrad);
561564 CHECK_STRIDE (gateGradOut);
562- CHECK_STRIDE (valueGradOut );
565+ CHECK_STRIDE (tokenGradOut );
563566
564567 const auto sizes = gates.sizes ();
565568 const int batch_size = sizes[0 ];
@@ -577,7 +580,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
577580 constexpr int kNChunksPerSequence = 1 ;
578581 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
579582 kNChunksPerSequence , grid, kNThreads , stream,
580- gates, output, outGrad, gateGradOut, valueGradOut ,
583+ gates, output, outGrad, gateGradOut, tokenGradOut ,
581584 batch_stride, dim_stride);
582585 } else if (seqlen == 64 ) {
583586 constexpr int kNStepsPerThread = 2 ;
@@ -586,7 +589,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
586589 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
587590 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
588591 kNChunksPerSequence , grid, kNThreads , stream,
589- gates, output, outGrad, gateGradOut, valueGradOut ,
592+ gates, output, outGrad, gateGradOut, tokenGradOut ,
590593 batch_stride, dim_stride);
591594 } else if (seqlen == 128 ) {
592595 constexpr int kNStepsPerThread = 1 ;
@@ -595,7 +598,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
595598 constexpr int kNChunksPerSequence = 1 ;
596599 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
597600 kNChunksPerSequence , grid, kNThreads , stream,
598- gates, output, outGrad, gateGradOut, valueGradOut ,
601+ gates, output, outGrad, gateGradOut, tokenGradOut ,
599602 batch_stride, dim_stride);
600603 } else if (seqlen == 256 ) {
601604 constexpr int kNStepsPerThread = 1 ;
@@ -604,7 +607,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
604607 constexpr int kNChunksPerSequence = 1 ;
605608 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
606609 kNChunksPerSequence , grid, kNThreads , stream,
607- gates, output, outGrad, gateGradOut, valueGradOut ,
610+ gates, output, outGrad, gateGradOut, tokenGradOut ,
608611 batch_stride, dim_stride);
609612 } else if (seqlen == 512 ) {
610613 constexpr int kNStepsPerThread = 1 ;
@@ -613,7 +616,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
613616 constexpr int kNChunksPerSequence = 1 ;
614617 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
615618 kNChunksPerSequence , grid, kNThreads , stream,
616- gates, output, outGrad, gateGradOut, valueGradOut ,
619+ gates, output, outGrad, gateGradOut, tokenGradOut ,
617620 batch_stride, dim_stride);
618621 } else if (seqlen == 1024 ) {
619622 constexpr int kNStepsPerThread = 2 ;
@@ -622,7 +625,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
622625 constexpr int kNChunksPerSequence = 1 ;
623626 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
624627 kNChunksPerSequence , grid, kNThreads , stream,
625- gates, output, outGrad, gateGradOut, valueGradOut ,
628+ gates, output, outGrad, gateGradOut, tokenGradOut ,
626629 batch_stride, dim_stride);
627630 } else if (seqlen == 2048 ) {
628631 constexpr int kNStepsPerThread = 2 ;
@@ -631,7 +634,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
631634 constexpr int kNChunksPerSequence = 1 ;
632635 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
633636 kNChunksPerSequence , grid, kNThreads , stream,
634- gates, output, outGrad, gateGradOut, valueGradOut ,
637+ gates, output, outGrad, gateGradOut, tokenGradOut ,
635638 batch_stride, dim_stride);
636639 } else if (seqlen == 4096 ) {
637640 constexpr int kNStepsPerThread = 4 ;
@@ -640,7 +643,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
640643 constexpr int kNChunksPerSequence = 1 ;
641644 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
642645 kNChunksPerSequence , grid, kNThreads , stream,
643- gates, output, outGrad, gateGradOut, valueGradOut ,
646+ gates, output, outGrad, gateGradOut, tokenGradOut ,
644647 batch_stride, dim_stride);
645648 } else if (seqlen == 8192 ) {
646649 constexpr int kNStepsPerThread = 4 ;
@@ -649,7 +652,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
649652 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
650653 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
651654 kNChunksPerSequence , grid, kNThreads , stream,
652- gates, output, outGrad, gateGradOut, valueGradOut ,
655+ gates, output, outGrad, gateGradOut, tokenGradOut ,
653656 batch_stride, dim_stride);
654657 } else if (seqlen == 16384 ) {
655658 constexpr int kNStepsPerThread = 4 ;
@@ -658,7 +661,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
658661 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
659662 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
660663 kNChunksPerSequence , grid, kNThreads , stream,
661- gates, output, outGrad, gateGradOut, valueGradOut ,
664+ gates, output, outGrad, gateGradOut, tokenGradOut ,
662665 batch_stride, dim_stride);
663666 } else if (seqlen == 32768 ) {
664667 constexpr int kNStepsPerThread = 4 ;
@@ -667,7 +670,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
667670 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
668671 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
669672 kNChunksPerSequence , grid, kNThreads , stream,
670- gates, output, outGrad, gateGradOut, valueGradOut ,
673+ gates, output, outGrad, gateGradOut, tokenGradOut ,
671674 batch_stride, dim_stride);
672675 } else if (seqlen == 65536 ) {
673676 constexpr int kNStepsPerThread = 4 ;
@@ -676,35 +679,35 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
676679 int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence ;
677680 DISPATCH_SCAN_GRAD (weight_t , kNStepsPerThread , kNThreadsPerWarp , kNWarpsPerBlock ,
678681 kNChunksPerSequence , grid, kNThreads , stream,
679- gates, output, outGrad, gateGradOut, valueGradOut ,
682+ gates, output, outGrad, gateGradOut, tokenGradOut ,
680683 batch_stride, dim_stride);
681684 } else {
682685 TORCH_CHECK (false && " seqlen must be a power of 2, >= 32, <= 65536" );
683686 }
684687}
685688
686689void
687- warpscan_backward (const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut ) {
690+ warpscan_backward (const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& tokenGradOut ) {
688691 TORCH_CHECK (gates.is_cuda ());
689692 TORCH_CHECK (output.is_cuda ());
690693 TORCH_CHECK (outGrad.is_cuda ());
691694 TORCH_CHECK (gateGradOut.is_contiguous ());
692- TORCH_CHECK (valueGradOut .is_contiguous ());
695+ TORCH_CHECK (tokenGradOut .is_contiguous ());
693696 TORCH_CHECK (gates.scalar_type () == output.scalar_type ());
694697 TORCH_CHECK (gates.scalar_type () == outGrad.scalar_type ());
695698 TORCH_CHECK (gates.scalar_type () == gateGradOut.scalar_type ());
696- TORCH_CHECK (gates.scalar_type () == valueGradOut .scalar_type ());
699+ TORCH_CHECK (gates.scalar_type () == tokenGradOut .scalar_type ());
697700 TORCH_CHECK (gates.sizes () == output.sizes ());
698701 TORCH_CHECK (gates.sizes () == outGrad.sizes ());
699702 TORCH_CHECK (gates.sizes () == gateGradOut.sizes ());
700- TORCH_CHECK (gates.sizes () == valueGradOut .sizes ());
703+ TORCH_CHECK (gates.sizes () == tokenGradOut .sizes ());
701704
702705 if (gates.scalar_type () == at::ScalarType::BFloat16) {
703- warpscan_grad<__nv_bfloat16, at::BFloat16>(gates, output, outGrad, gateGradOut, valueGradOut );
706+ warpscan_grad<__nv_bfloat16, at::BFloat16>(gates, output, outGrad, gateGradOut, tokenGradOut );
704707 } else if (gates.scalar_type () == at::ScalarType::Half) {
705- warpscan_grad<__half, at::Half>(gates, output, outGrad, gateGradOut, valueGradOut );
708+ warpscan_grad<__half, at::Half>(gates, output, outGrad, gateGradOut, tokenGradOut );
706709 } else if (gates.scalar_type () == at::ScalarType::Float) {
707- warpscan_grad<float , float >(gates, output, outGrad, gateGradOut, valueGradOut );
710+ warpscan_grad<float , float >(gates, output, outGrad, gateGradOut, tokenGradOut );
708711 } else {
709712 TORCH_CHECK (false && " Unsupported tensor dtype: expecting bfloat16, float16 or float32" );
710713 }
0 commit comments