Skip to content

Commit a36ac45

Browse files
fix: fast redux detection in trtllm gen routing kernel (NVIDIA#5941)
Signed-off-by: Yuan Tong <[email protected]>
1 parent 3dfc819 commit a36ac45

File tree

2 files changed

+4
-12
lines changed

2 files changed

+4
-12
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828

2929
#include <type_traits>
3030

31-
#include "tensorrt_llm/kernels/archCondition.h"
32-
3331
////////////////////////////////////////////////////////////////////////////////////////////////////
3432

3533
namespace moe::dev
@@ -53,10 +51,6 @@ static constexpr int NumWarpsHist = NumThreadsHist / WarpSize;
5351

5452
////////////////////////////////////////////////////////////////////////////////////////////////////
5553

56-
static constexpr bool TLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;
57-
58-
////////////////////////////////////////////////////////////////////////////////////////////////////
59-
6054
static __device__ inline float sigmoid_accurate(float x)
6155
{
6256
return 0.5f * tanhf(0.5f * x) + 0.5f;

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include <cooperative_groups/reduce.h>
2121
#include <cub/cub.cuh>
2222

23+
#include "tensorrt_llm/kernels/archCondition.h"
24+
2325
namespace moe::dev::routing
2426
{
2527

@@ -86,12 +88,8 @@ struct TopKRedType
8688

8789
__device__ inline TypeCmp reduce(cg::thread_block_tile<WarpSize> const& warp)
8890
{
89-
#if defined(TLLM_GEN_HAS_FAST_REDUX)
90-
static constexpr bool UseCg = false;
91-
#else
92-
static constexpr bool UseCg = true;
93-
#endif
94-
if constexpr (UseCg || sizeof(TypeCmp) == 8)
91+
static constexpr bool hasFastRedux = tensorrt_llm::kernels::arch::is_major_v<10>;
92+
if constexpr (!hasFastRedux || sizeof(TypeCmp) == 8)
9593
{
9694
return cg::reduce(warp, compVal, cg::greater<TypeCmp>{});
9795
}

0 commit comments

Comments
 (0)