Commit b8a5159

[None][feat] Enable PDL for indexer topK (#9843)
Signed-off-by: Christina Zhang <[email protected]>
1 parent d147ad0 commit b8a5159
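
This commit turns on Programmatic Dependent Launch (PDL) for the indexer top-k kernels. On Hopper (compute capability 9.0) and newer GPUs, a kernel launched with the programmatic stream serialization attribute may begin running before the previous kernel on the same stream has fully completed; the device-side pair cudaGridDependencySynchronize() / cudaTriggerProgrammaticLaunchCompletion() then enforces the real data dependency. The diff adds those two calls to the prefill and decode kernels (guarded to __CUDA_ARCH__ >= 900) and rewrites the host launches from <<<...>>> syntax to cudaLaunchKernelEx so the attribute can be attached, gated at runtime by tensorrt_llm::common::getEnvEnablePDL(). Hedged sketches of the device-side pattern (after the kernel hunks) and of the host-side launch (after the diff) follow below.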

File tree

1 file changed: +68 −11 lines

cpp/tensorrt_llm/kernels/indexerTopK.cu

Lines changed: 68 additions & 11 deletions
@@ -589,6 +589,9 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
     int const* rowStarts, int const* rowEnds, int* outIndices, int stride0, int stride1, int const topK,
     int const offsetIndex)
 {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+#endif
     // The number of bins in the histogram.
     static constexpr int kNumBins = 2048;

@@ -605,13 +608,19 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(

     topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
         nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }

 template <int kNumThreadsPerBlock, bool useRadixSort, bool multipleBlocksPerRow = false, bool mergeBlocks = false>
 static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(float const* logits, int const* seqLens,
     int* outIndices, int stride0, int stride1, int const topK, int next_n, float* outLogits = nullptr,
     int const numBlocksToMerge = 0, int const* indices = nullptr)
 {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+#endif
     // The number of bins in the histogram.
     static constexpr int kNumBins = 2048;

@@ -646,6 +655,9 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(f

     topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort, multipleBlocksPerRow, mergeBlocks>(
         indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK);
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }

 void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, float* outLogitsAux,
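
Before the host-side hunk, a minimal self-contained sketch of the device-side pattern the three hunks above install. The kernel pdlScale and its body are illustrative assumptions, not code from this commit:

#include <cuda_runtime.h>

// Illustrative kernel (not from this commit) showing where the two
// device-side PDL calls sit relative to the kernel's real work.
__global__ void pdlScale(float const* in, float* out, int n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Block until the preceding grid on the stream has finished and its
    // memory writes are visible; code before this call must not read
    // anything the previous kernel produced.
    cudaGridDependencySynchronize();
#endif
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        out[i] = 2.0f * in[i];
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Let the next grid in the stream begin launching early; if omitted,
    // the trigger fires implicitly when this grid completes.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}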
@@ -660,28 +672,73 @@ void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indic
     if (numColumns < kSortingAlgorithmThreshold)
     {
         // Use insertion sort
-        topKPerRowDecode<kNumThreadsPerBlock, false><<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
-            logits, seqLens, indices, stride0, stride1, topK, next_n);
+        auto* kernel_instance = &topKPerRowDecode<kNumThreadsPerBlock, false>;
+
+        cudaLaunchConfig_t config;
+        config.gridDim = numRows;
+        config.blockDim = kNumThreadsPerBlock;
+        config.dynamicSmemBytes = topK * sizeof(int32_t);
+        config.stream = stream;
+        cudaLaunchAttribute attrs[1];
+        attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+        attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+        config.numAttrs = 1;
+        config.attrs = attrs;
+
+        cudaLaunchKernelEx(
+            &config, kernel_instance, logits, seqLens, indices, stride0, stride1, topK, next_n, nullptr, 0, nullptr);
     }
     else if (numColumns < kSplitWorkThreshold)
     {
         // From this threshold, use radix sort instead
-        topKPerRowDecode<kNumThreadsPerBlock, true><<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
-            logits, seqLens, indices, stride0, stride1, topK, next_n);
+        auto* kernel_instance = &topKPerRowDecode<kNumThreadsPerBlock, true>;
+
+        cudaLaunchConfig_t config;
+        config.gridDim = numRows;
+        config.blockDim = kNumThreadsPerBlock;
+        config.dynamicSmemBytes = topK * sizeof(int32_t);
+        config.stream = stream;
+        cudaLaunchAttribute attrs[1];
+        attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+        attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+        config.numAttrs = 1;
+        config.attrs = attrs;
+
+        cudaLaunchKernelEx(
+            &config, kernel_instance, logits, seqLens, indices, stride0, stride1, topK, next_n, nullptr, 0, nullptr);
     }
     else
     {
         // Long sequences are run in two steps
         constexpr auto multipleBlocksPerRowConfig = 10;
-
-        topKPerRowDecode<kNumThreadsPerBlock, true, true>
-            <<<dim3(numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock, 2 * topK * sizeof(int32_t), stream>>>(
-                logits, seqLens, outIndicesAux, stride0, stride1, topK, next_n, outLogitsAux);
+        auto* kernel_instance_part1 = &topKPerRowDecode<kNumThreadsPerBlock, true, true>;
+        cudaLaunchConfig_t config_part1;
+        config_part1.gridDim = dim3(numRows, multipleBlocksPerRowConfig);
+        config_part1.blockDim = kNumThreadsPerBlock;
+        config_part1.dynamicSmemBytes = 2 * topK * sizeof(int32_t);
+        config_part1.stream = stream;
+        cudaLaunchAttribute attrs[1];
+        attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+        attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+        config_part1.numAttrs = 1;
+        config_part1.attrs = attrs;
+
+        cudaLaunchKernelEx(&config_part1, kernel_instance_part1, logits, seqLens, outIndicesAux, stride0, stride1, topK,
+            next_n, outLogitsAux, 0, nullptr);

         constexpr int kNumThreadsPerBlockMerge = 1024;
-        topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
-            <<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(outLogitsAux, seqLens, indices,
-                multipleBlocksPerRowConfig * topK, 1, topK, next_n, nullptr, multipleBlocksPerRowConfig, outIndicesAux);
+        auto* kernel_instance_part2 = &topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>;
+        cudaLaunchConfig_t config_part2;
+        config_part2.gridDim = numRows;
+        config_part2.blockDim = kNumThreadsPerBlockMerge;
+        config_part2.dynamicSmemBytes = topK * sizeof(int32_t);
+        config_part2.stream = stream;
+        // Reuse attrs array since part1 kernel has already been launched
+        config_part2.numAttrs = 1;
+        config_part2.attrs = attrs;
+
+        cudaLaunchKernelEx(&config_part2, kernel_instance_part2, outLogitsAux, seqLens, indices,
+            multipleBlocksPerRowConfig * topK, 1, topK, next_n, nullptr, multipleBlocksPerRowConfig, outIndicesAux);
     }
     sync_check_cuda_error(stream);
 }
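
And a sketch of the host-side half in isolation. launchPdlScale is our name, enablePDL stands in for tensorrt_llm::common::getEnvEnablePDL(), and pdlScale is the illustrative kernel from the sketch above:

#include <cuda_runtime.h>

__global__ void pdlScale(float const* in, float* out, int n); // see sketch above

// Attach the programmatic stream serialization attribute at launch time.
// With the attribute value 0 this degrades to an ordinary stream-ordered
// launch, so the same code path serves PDL-on and PDL-off runs.
void launchPdlScale(float const* in, float* out, int n, bool enablePDL, cudaStream_t stream)
{
    cudaLaunchConfig_t config{};
    config.gridDim = dim3((n + 255) / 256);
    config.blockDim = dim3(256);
    config.dynamicSmemBytes = 0;
    config.stream = stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = enablePDL ? 1 : 0;
    config.attrs = attrs;
    config.numAttrs = 1;

    cudaLaunchKernelEx(&config, pdlScale, in, out, n);
}

One detail worth noting in the hunk above: config_part2 reuses the attrs array filled in for config_part1. That is safe because cudaLaunchKernelEx reads the attribute list during the launch call itself, so the array can be reused as soon as the part-1 call returns.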
