Skip to content

Commit 8a42373

Browse files
committed
rename; fix bounds
1 parent 36d1abe commit 8a42373

File tree

1 file changed

+35
-32
lines changed

1 file changed

+35
-32
lines changed

accelerated_scan/warp.cuh

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -347,14 +347,14 @@ warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Te
347347
}
348348

349349
template <typename Tuple, int offset>
350-
__device__ Tuple load_shifted_tuple(const Tuple* ptr, int index, int limit) {
350+
__device__ Tuple load_shifted_tuple(const Tuple* ptr, int index, int minIdx, int maxIdx) {
351351
using weight_t = typename Tuple::Type;
352352

353353
const weight_t* rawPtr = reinterpret_cast<const weight_t *>(ptr);
354354
Tuple x;
355355
for (int i = 0; i < Tuple::Size; i++) {
356356
const int idx = index * Tuple::Size + i + offset;
357-
if (idx >= 0 && idx < limit * Tuple::Size) {
357+
if (idx >= minIdx * Tuple::Size && idx < maxIdx * Tuple::Size) {
358358
x.data[i] = rawPtr[idx];
359359
} else {
360360
x.data[i] = 0.0;
@@ -370,7 +370,7 @@ __global__ void scan_grad(
370370
const Tuple* output,
371371
const Tuple* outGrad,
372372
Tuple* gateGradOut,
373-
Tuple* valueGradOut,
373+
Tuple* tokenGradOut,
374374
const int batch_stride,
375375
const int dim_stride
376376
) {
@@ -389,7 +389,10 @@ __global__ void scan_grad(
389389
constexpr int kThreadLast = Tuple::Size - 1;
390390
const weight_t kEmptyGate = 1.0;
391391
const weight_t kEmptyToken = 0.0;
392-
const int limit = blockDim.x * kNChunksPerSequence;
392+
393+
// Limits for loading shifted tuples.
394+
const int minIdx = blockDim.x * kNChunksPerSequence * blockIdx.x;
395+
const int maxIdx = blockDim.x * kNChunksPerSequence * (blockIdx.x + 1);
393396

394397
for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
395398
const int offset = seqoffset + (kNChunksPerSequence - 1 - chunk) * chunklen;
@@ -400,7 +403,7 @@ __global__ void scan_grad(
400403
}
401404

402405
// Load from global memory.
403-
Tuple loadedGate = load_shifted_tuple<Tuple, 1>(gates, tupleOffset, limit);
406+
Tuple loadedGate = load_shifted_tuple<Tuple, 1>(gates, tupleOffset, minIdx, maxIdx);
404407
Tuple loadedToken = outGrad[tupleOffset];
405408
loadedGate.reverse();
406409
loadedToken.reverse();
@@ -505,9 +508,9 @@ __global__ void scan_grad(
505508
}
506509
}
507510
accToken.reverse();
508-
valueGradOut[tupleOffset] = accToken;
511+
tokenGradOut[tupleOffset] = accToken;
509512

510-
Tuple gateGrad = load_shifted_tuple<Tuple, -1>(output, tupleOffset, limit);
513+
Tuple gateGrad = load_shifted_tuple<Tuple, -1>(output, tupleOffset, minIdx, maxIdx);
511514
for (int i = 0; i < Tuple::Size; i++) {
512515
gateGrad.data[i] = gateGrad.data[i] * accToken.data[i];
513516
}
@@ -520,21 +523,21 @@ __global__ void scan_grad(
520523
}
521524
}
522525

523-
#define DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, output, outGrad, gateGradOut, valueGradOut, batch_stride, dim_stride) \
526+
#define DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, output, outGrad, gateGradOut, tokenGradOut, batch_stride, dim_stride) \
524527
using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
525528
using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
526529
if (kNStepsPerThread == 4 && \
527530
((long)gates.data_ptr()) % 16 == 0 && \
528531
((long)output.data_ptr()) % 16 == 0 && \
529532
((long)outGrad.data_ptr()) % 16 == 0 && \
530533
((long)gateGradOut.data_ptr()) % 16 == 0 && \
531-
((long)valueGradOut.data_ptr()) % 16 == 0) { \
534+
((long)tokenGradOut.data_ptr()) % 16 == 0) { \
532535
scan_grad<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
533536
reinterpret_cast<const AlignedT *>(gates.data_ptr<torch_weight_t>()), \
534537
reinterpret_cast<const AlignedT *>(output.data_ptr<torch_weight_t>()), \
535538
reinterpret_cast<const AlignedT *>(outGrad.data_ptr<torch_weight_t>()), \
536539
reinterpret_cast<AlignedT *>(gateGradOut.data_ptr<torch_weight_t>()), \
537-
reinterpret_cast<AlignedT *>(valueGradOut.data_ptr<torch_weight_t>()), \
540+
reinterpret_cast<AlignedT *>(tokenGradOut.data_ptr<torch_weight_t>()), \
538541
batch_stride, dim_stride \
539542
); \
540543
} else { \
@@ -543,23 +546,23 @@ __global__ void scan_grad(
543546
reinterpret_cast<const UnalignedT *>(output.data_ptr<torch_weight_t>()), \
544547
reinterpret_cast<const UnalignedT *>(outGrad.data_ptr<torch_weight_t>()), \
545548
reinterpret_cast<UnalignedT *>(gateGradOut.data_ptr<torch_weight_t>()), \
546-
reinterpret_cast<UnalignedT *>(valueGradOut.data_ptr<torch_weight_t>()), \
549+
reinterpret_cast<UnalignedT *>(tokenGradOut.data_ptr<torch_weight_t>()), \
547550
batch_stride, dim_stride \
548551
); \
549552
}
550553

551554
template <typename weight_t, typename torch_weight_t>
552555
void
553556
warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad,
554-
const at::Tensor& gateGradOut, const at::Tensor& valueGradOut) {
557+
const at::Tensor& gateGradOut, const at::Tensor& tokenGradOut) {
555558
const auto strides = gates.strides();
556559
const int batch_stride = strides[0];
557560
const int dim_stride = strides[1];
558561
CHECK_STRIDE(gates);
559562
CHECK_STRIDE(output);
560563
CHECK_STRIDE(outGrad);
561564
CHECK_STRIDE(gateGradOut);
562-
CHECK_STRIDE(valueGradOut);
565+
CHECK_STRIDE(tokenGradOut);
563566

564567
const auto sizes = gates.sizes();
565568
const int batch_size = sizes[0];
@@ -577,7 +580,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
577580
constexpr int kNChunksPerSequence = 1;
578581
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
579582
kNChunksPerSequence, grid, kNThreads, stream,
580-
gates, output, outGrad, gateGradOut, valueGradOut,
583+
gates, output, outGrad, gateGradOut, tokenGradOut,
581584
batch_stride, dim_stride);
582585
} else if (seqlen == 64) {
583586
constexpr int kNStepsPerThread = 2;
@@ -586,7 +589,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
586589
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
587590
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
588591
kNChunksPerSequence, grid, kNThreads, stream,
589-
gates, output, outGrad, gateGradOut, valueGradOut,
592+
gates, output, outGrad, gateGradOut, tokenGradOut,
590593
batch_stride, dim_stride);
591594
} else if (seqlen == 128) {
592595
constexpr int kNStepsPerThread = 1;
@@ -595,7 +598,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
595598
constexpr int kNChunksPerSequence = 1;
596599
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
597600
kNChunksPerSequence, grid, kNThreads, stream,
598-
gates, output, outGrad, gateGradOut, valueGradOut,
601+
gates, output, outGrad, gateGradOut, tokenGradOut,
599602
batch_stride, dim_stride);
600603
} else if (seqlen == 256) {
601604
constexpr int kNStepsPerThread = 1;
@@ -604,7 +607,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
604607
constexpr int kNChunksPerSequence = 1;
605608
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
606609
kNChunksPerSequence, grid, kNThreads, stream,
607-
gates, output, outGrad, gateGradOut, valueGradOut,
610+
gates, output, outGrad, gateGradOut, tokenGradOut,
608611
batch_stride, dim_stride);
609612
} else if (seqlen == 512) {
610613
constexpr int kNStepsPerThread = 1;
@@ -613,7 +616,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
613616
constexpr int kNChunksPerSequence = 1;
614617
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
615618
kNChunksPerSequence, grid, kNThreads, stream,
616-
gates, output, outGrad, gateGradOut, valueGradOut,
619+
gates, output, outGrad, gateGradOut, tokenGradOut,
617620
batch_stride, dim_stride);
618621
} else if (seqlen == 1024) {
619622
constexpr int kNStepsPerThread = 2;
@@ -622,7 +625,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
622625
constexpr int kNChunksPerSequence = 1;
623626
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
624627
kNChunksPerSequence, grid, kNThreads, stream,
625-
gates, output, outGrad, gateGradOut, valueGradOut,
628+
gates, output, outGrad, gateGradOut, tokenGradOut,
626629
batch_stride, dim_stride);
627630
} else if (seqlen == 2048) {
628631
constexpr int kNStepsPerThread = 2;
@@ -631,7 +634,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
631634
constexpr int kNChunksPerSequence = 1;
632635
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
633636
kNChunksPerSequence, grid, kNThreads, stream,
634-
gates, output, outGrad, gateGradOut, valueGradOut,
637+
gates, output, outGrad, gateGradOut, tokenGradOut,
635638
batch_stride, dim_stride);
636639
} else if (seqlen == 4096) {
637640
constexpr int kNStepsPerThread = 4;
@@ -640,7 +643,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
640643
constexpr int kNChunksPerSequence = 1;
641644
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
642645
kNChunksPerSequence, grid, kNThreads, stream,
643-
gates, output, outGrad, gateGradOut, valueGradOut,
646+
gates, output, outGrad, gateGradOut, tokenGradOut,
644647
batch_stride, dim_stride);
645648
} else if (seqlen == 8192) {
646649
constexpr int kNStepsPerThread = 4;
@@ -649,7 +652,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
649652
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
650653
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
651654
kNChunksPerSequence, grid, kNThreads, stream,
652-
gates, output, outGrad, gateGradOut, valueGradOut,
655+
gates, output, outGrad, gateGradOut, tokenGradOut,
653656
batch_stride, dim_stride);
654657
} else if (seqlen == 16384) {
655658
constexpr int kNStepsPerThread = 4;
@@ -658,7 +661,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
658661
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
659662
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
660663
kNChunksPerSequence, grid, kNThreads, stream,
661-
gates, output, outGrad, gateGradOut, valueGradOut,
664+
gates, output, outGrad, gateGradOut, tokenGradOut,
662665
batch_stride, dim_stride);
663666
} else if (seqlen == 32768) {
664667
constexpr int kNStepsPerThread = 4;
@@ -667,7 +670,7 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
667670
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
668671
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
669672
kNChunksPerSequence, grid, kNThreads, stream,
670-
gates, output, outGrad, gateGradOut, valueGradOut,
673+
gates, output, outGrad, gateGradOut, tokenGradOut,
671674
batch_stride, dim_stride);
672675
} else if (seqlen == 65536) {
673676
constexpr int kNStepsPerThread = 4;
@@ -676,35 +679,35 @@ warpscan_grad(const at::Tensor &gates, const at::Tensor &output, const at::Tenso
676679
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
677680
DISPATCH_SCAN_GRAD(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
678681
kNChunksPerSequence, grid, kNThreads, stream,
679-
gates, output, outGrad, gateGradOut, valueGradOut,
682+
gates, output, outGrad, gateGradOut, tokenGradOut,
680683
batch_stride, dim_stride);
681684
} else {
682685
TORCH_CHECK(false && "seqlen must be a power of 2, >= 32, <= 65536");
683686
}
684687
}
685688

686689
void
687-
warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut) {
690+
warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& tokenGradOut) {
688691
TORCH_CHECK(gates.is_cuda());
689692
TORCH_CHECK(output.is_cuda());
690693
TORCH_CHECK(outGrad.is_cuda());
691694
TORCH_CHECK(gateGradOut.is_contiguous());
692-
TORCH_CHECK(valueGradOut.is_contiguous());
695+
TORCH_CHECK(tokenGradOut.is_contiguous());
693696
TORCH_CHECK(gates.scalar_type() == output.scalar_type());
694697
TORCH_CHECK(gates.scalar_type() == outGrad.scalar_type());
695698
TORCH_CHECK(gates.scalar_type() == gateGradOut.scalar_type());
696-
TORCH_CHECK(gates.scalar_type() == valueGradOut.scalar_type());
699+
TORCH_CHECK(gates.scalar_type() == tokenGradOut.scalar_type());
697700
TORCH_CHECK(gates.sizes() == output.sizes());
698701
TORCH_CHECK(gates.sizes() == outGrad.sizes());
699702
TORCH_CHECK(gates.sizes() == gateGradOut.sizes());
700-
TORCH_CHECK(gates.sizes() == valueGradOut.sizes());
703+
TORCH_CHECK(gates.sizes() == tokenGradOut.sizes());
701704

702705
if (gates.scalar_type() == at::ScalarType::BFloat16) {
703-
warpscan_grad<__nv_bfloat16, at::BFloat16>(gates, output, outGrad, gateGradOut, valueGradOut);
706+
warpscan_grad<__nv_bfloat16, at::BFloat16>(gates, output, outGrad, gateGradOut, tokenGradOut);
704707
} else if (gates.scalar_type() == at::ScalarType::Half) {
705-
warpscan_grad<__half, at::Half>(gates, output, outGrad, gateGradOut, valueGradOut);
708+
warpscan_grad<__half, at::Half>(gates, output, outGrad, gateGradOut, tokenGradOut);
706709
} else if (gates.scalar_type() == at::ScalarType::Float) {
707-
warpscan_grad<float, float>(gates, output, outGrad, gateGradOut, valueGradOut);
710+
warpscan_grad<float, float>(gates, output, outGrad, gateGradOut, tokenGradOut);
708711
} else {
709712
TORCH_CHECK(false && "Unsupported tensor dtype: expecting bfloat16, float16 or float32");
710713
}

0 commit comments

Comments
 (0)