@@ -589,6 +589,9 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
589589 int const * rowStarts, int const * rowEnds, int * outIndices, int stride0, int stride1, int const topK,
590590 int const offsetIndex)
591591{
592+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
593+ cudaGridDependencySynchronize ();
594+ #endif
592595 // The number of bins in the histogram.
593596 static constexpr int kNumBins = 2048 ;
594597
@@ -605,13 +608,19 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
605608
606609 topKPerRowJob<kNumThreadsPerBlock , kNumBins , useRadixSort>(
607610 nullptr , logits, rowStart, rowEnd, outIndices, nullptr , stride1, topK);
611+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
612+ cudaTriggerProgrammaticLaunchCompletion ();
613+ #endif
608614}
609615
610616template <int kNumThreadsPerBlock , bool useRadixSort, bool multipleBlocksPerRow = false , bool mergeBlocks = false >
611617static __global__ __launch_bounds__ (kNumThreadsPerBlock ) void topKPerRowDecode(float const * logits, int const * seqLens,
612618 int * outIndices, int stride0, int stride1, int const topK, int next_n, float * outLogits = nullptr ,
613619 int const numBlocksToMerge = 0 , int const * indices = nullptr )
614620{
621+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
622+ cudaGridDependencySynchronize ();
623+ #endif
615624 // The number of bins in the histogram.
616625 static constexpr int kNumBins = 2048 ;
617626
@@ -646,6 +655,9 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(f
646655
647656 topKPerRowJob<kNumThreadsPerBlock , kNumBins , useRadixSort, multipleBlocksPerRow, mergeBlocks>(
648657 indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK);
658+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
659+ cudaTriggerProgrammaticLaunchCompletion ();
660+ #endif
649661}
650662
651663void invokeIndexerTopKDecode (float const * logits, int const * seqLens, int * indices, float * outLogitsAux,
@@ -660,28 +672,73 @@ void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indic
660672 if (numColumns < kSortingAlgorithmThreshold )
661673 {
662674 // Use insertion sort
663- topKPerRowDecode<kNumThreadsPerBlock , false ><<<numRows, kNumThreadsPerBlock , topK * sizeof (int32_t ), stream>>> (
664- logits, seqLens, indices, stride0, stride1, topK, next_n);
675+ auto * kernel_instance = &topKPerRowDecode<kNumThreadsPerBlock , false >;
676+
677+ cudaLaunchConfig_t config;
678+ config.gridDim = numRows;
679+ config.blockDim = kNumThreadsPerBlock ;
680+ config.dynamicSmemBytes = topK * sizeof (int32_t );
681+ config.stream = stream;
682+ cudaLaunchAttribute attrs[1 ];
683+ attrs[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
684+ attrs[0 ].val .programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL ();
685+ config.numAttrs = 1 ;
686+ config.attrs = attrs;
687+
688+ cudaLaunchKernelEx (
689+ &config, kernel_instance, logits, seqLens, indices, stride0, stride1, topK, next_n, nullptr , 0 , nullptr );
665690 }
666691 else if (numColumns < kSplitWorkThreshold )
667692 {
668693 // From this threshold, use radix sort instead
669- topKPerRowDecode<kNumThreadsPerBlock , true ><<<numRows, kNumThreadsPerBlock , topK * sizeof (int32_t ), stream>>> (
670- logits, seqLens, indices, stride0, stride1, topK, next_n);
694+ auto * kernel_instance = &topKPerRowDecode<kNumThreadsPerBlock , true >;
695+
696+ cudaLaunchConfig_t config;
697+ config.gridDim = numRows;
698+ config.blockDim = kNumThreadsPerBlock ;
699+ config.dynamicSmemBytes = topK * sizeof (int32_t );
700+ config.stream = stream;
701+ cudaLaunchAttribute attrs[1 ];
702+ attrs[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
703+ attrs[0 ].val .programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL ();
704+ config.numAttrs = 1 ;
705+ config.attrs = attrs;
706+
707+ cudaLaunchKernelEx (
708+ &config, kernel_instance, logits, seqLens, indices, stride0, stride1, topK, next_n, nullptr , 0 , nullptr );
671709 }
672710 else
673711 {
674712 // Long sequences are run in two steps
675713 constexpr auto multipleBlocksPerRowConfig = 10 ;
676-
677- topKPerRowDecode<kNumThreadsPerBlock , true , true >
678- <<<dim3 (numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock , 2 * topK * sizeof (int32_t ), stream>>> (
679- logits, seqLens, outIndicesAux, stride0, stride1, topK, next_n, outLogitsAux);
714+ auto * kernel_instance_part1 = &topKPerRowDecode<kNumThreadsPerBlock , true , true >;
715+ cudaLaunchConfig_t config_part1;
716+ config_part1.gridDim = dim3 (numRows, multipleBlocksPerRowConfig);
717+ config_part1.blockDim = kNumThreadsPerBlock ;
718+ config_part1.dynamicSmemBytes = 2 * topK * sizeof (int32_t );
719+ config_part1.stream = stream;
720+ cudaLaunchAttribute attrs[1 ];
721+ attrs[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
722+ attrs[0 ].val .programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL ();
723+ config_part1.numAttrs = 1 ;
724+ config_part1.attrs = attrs;
725+
726+ cudaLaunchKernelEx (&config_part1, kernel_instance_part1, logits, seqLens, outIndicesAux, stride0, stride1, topK,
727+ next_n, outLogitsAux, 0 , nullptr );
680728
681729 constexpr int kNumThreadsPerBlockMerge = 1024 ;
682- topKPerRowDecode<kNumThreadsPerBlockMerge , true , false , true >
683- <<<numRows, kNumThreadsPerBlockMerge , topK * sizeof (int32_t ), stream>>> (outLogitsAux, seqLens, indices,
684- multipleBlocksPerRowConfig * topK, 1 , topK, next_n, nullptr , multipleBlocksPerRowConfig, outIndicesAux);
730+ auto * kernel_instance_part2 = &topKPerRowDecode<kNumThreadsPerBlockMerge , true , false , true >;
731+ cudaLaunchConfig_t config_part2;
732+ config_part2.gridDim = numRows;
733+ config_part2.blockDim = kNumThreadsPerBlockMerge ;
734+ config_part2.dynamicSmemBytes = topK * sizeof (int32_t );
735+ config_part2.stream = stream;
 736+        // Safe to reuse the same attrs array: cudaLaunchKernelEx reads the attributes at launch
 737+        // time, and part2 uses the identical PDL stream-serialization setting as part1.
737+ config_part2.numAttrs = 1 ;
738+ config_part2.attrs = attrs;
739+
740+ cudaLaunchKernelEx (&config_part2, kernel_instance_part2, outLogitsAux, seqLens, indices,
741+ multipleBlocksPerRowConfig * topK, 1 , topK, next_n, nullptr , multipleBlocksPerRowConfig, outIndicesAux);
685742 }
686743 sync_check_cuda_error (stream);
687744}
0 commit comments