@@ -882,9 +882,8 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices
     return target_location + 1;
 }
 
-template <class T>
+template <int ARR_LENGTH_CONST, class T>
 __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
-    constexpr int ARR_LENGTH_CONST = 128;
     if (arr_length != ARR_LENGTH_CONST) {
         asm("trap;");
     }
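Hoisting the array length into a template parameter is the point of this change: with `ARR_LENGTH_CONST` known at compile time, the search loop has a constant bound and can be fully unrolled. The body of `_v2` is elided from this diff, so the following is only an illustrative sketch of such a fixed-length count, not the actual implementation:

```cpp
#include <cstdint>

// Illustrative sketch only; the real findTotalEltsLessThanTarget_v2 body is
// not shown in this diff. With the length fixed at compile time, the loop
// bound is a constant and #pragma unroll can unroll it completely.
template <int ARR_LENGTH_CONST, class T>
__device__ inline int64_t countLessThanSketch(T const* sorted_indices, int64_t const arr_length, T const target) {
    if (arr_length != ARR_LENGTH_CONST) {
        asm("trap;");  // instantiated for a size it was not built for: fail loudly
    }
    int total = 0;
#pragma unroll
    for (int i = 0; i < ARR_LENGTH_CONST; ++i) {
        // On a sorted array, counting elements below the target yields the
        // same result as a binary search for the lower bound.
        total += sorted_indices[i] < target ? 1 : 0;
    }
    return (int64_t)total;
}
```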
@@ -910,11 +909,11 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices
     return (int64_t)total;
 }
 
-template <class T>
+template <int ARR_LENGTH_CONST, class T>
 __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
     // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
-    return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
+    return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
 
     // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
     // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
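The commented-out lines suggest the two implementations were cross-checked against each other during development. A minimal sketch of such a check, with a hypothetical wrapper name that is not part of the patch:

```cpp
// Hypothetical debug wrapper, not part of the patch: evaluate both search
// implementations and trap the kernel if they ever disagree.
template <int ARR_LENGTH_CONST, class T>
__device__ inline int64_t findTotalEltsLessThanTargetChecked(
    T const* sorted_indices, int64_t const arr_length, T const target) {
    int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
    int64_t out_v2 = findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
    if (out_v1 != out_v2) {
        asm("trap;");  // the optimized path diverged from the reference
    }
    return out_v2;
}
```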
@@ -1462,7 +1461,7 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
 template <class InputActivationsType, class ExpandedActivationsType,
     TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
-    bool PRE_QUANT_AWQ>
+    bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
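Giving the new parameter a default of 128 keeps every existing instantiation of `expandInputRowsKernel` compiling unchanged; only launch sites that want a different expert count need to pass it explicitly. A toy illustration of the pattern (the kernel and its arguments here are placeholders, not code from the patch):

```cpp
// Toy illustration of a defaulted non-type template parameter; kernelSketch
// is a placeholder, not a kernel from the patch.
template <class T, int NUM_EXPERTS_PER_NODE_CONST = 128>
__global__ void kernelSketch(T const* in, T* out) {
    // A real kernel would use NUM_EXPERTS_PER_NODE_CONST as a compile-time bound.
}

void launchSketch(float const* in, float* out) {
    kernelSketch<float><<<1, 128>>>(in, out);      // old call sites: defaults to 128
    kernelSketch<float, 64><<<1, 128>>>(in, out);  // new call sites: opt into 64
}
```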
@@ -1557,7 +1556,7 @@ __global__ void expandInputRowsKernel(
     if constexpr (is_nvfp4 || is_mxfp8) {
         static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
-        int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+        int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
                              (int64_t)permuted_row + 1) -
                          1;
@@ -1735,9 +1734,20 @@ void expandInputRowsKernelLauncher(
         TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale,
             "NVFP4 block scaling is expected for FP4xFP4");
         TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ");
-        return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
-            TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
-            false>;
+        if (num_experts_per_node == 128) {
+            constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+            return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+                TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
+                false, NUM_EXPERTS_PER_NODE_CONST>;
+        }
+        if (num_experts_per_node == 64) {
+            constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+            return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+                TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
+                false, NUM_EXPERTS_PER_NODE_CONST>;
+        }
+        printf("unsupported num_experts_per_node\n");
+        exit(1);
     } else
 #endif
     {
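The same two-way branch on `num_experts_per_node` reappears below for the activation kernels. One way to avoid repeating the ladder would be a small host-side helper that maps the runtime count onto a compile-time constant once; this is a sketch of an alternative, not what the patch does, and it assumes the callable returns the same type for every supported count:

```cpp
#include <cstdio>
#include <cstdlib>
#include <type_traits>

// Sketch of an alternative to the repeated if-ladders (not what the patch
// does): convert the runtime expert count into std::integral_constant once
// and hand it to a generic callable that does the template instantiation.
template <class Fn>
auto dispatchNumExperts(int num_experts_per_node, Fn&& fn) {
    if (num_experts_per_node == 128) {
        return fn(std::integral_constant<int, 128>{});
    }
    if (num_experts_per_node == 64) {
        return fn(std::integral_constant<int, 64>{});
    }
    printf("unsupported num_experts_per_node\n");
    exit(1);  // noreturn, so falling off the end here is fine
}
```

A caller would then select a kernel from a single definition, e.g. `dispatchNumExperts(n, [&](auto c) { return &expandInputRowsKernel<..., decltype(c)::value>; })`.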
@@ -2159,7 +2169,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output,
     float gate_bias = 0.0f;
     float gate_limit = std::numeric_limits<float>::infinity();
     if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) {
-        int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+        int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node,
                          (int64_t)token + 1) -
                      1;
         gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f;
@@ -2207,7 +2217,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
 // ============================== Activation =================================
 
 template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
-    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
+    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
     float const* fp8_quant, ScaleBiasType const* bias_ptr,
     bool bias_is_broadcast, int64_t const* expert_first_token_offset,
@@ -2270,7 +2280,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
             activation_params.swiglu_limit) {
             // TODO this is almost certainly faster as a linear scan
             expert =
-                findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
+                findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
                 1;
             gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
             gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
@@ -2444,30 +2454,62 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
     auto fn = [&]() {
         auto fn = [&](auto block_scaling_type) {
-            auto fn_list = std::array{
-                &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                    IdentityAdaptor<cutlass::epilogue::thread::GELU>,
-                    decltype(block_scaling_type)::value>, // Gelu
-                &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                    IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
-                    decltype(block_scaling_type)::value>, // Relu
-                &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                    IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
-                    decltype(block_scaling_type)::value>, // Silu
-                &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                    GLUAdaptor<cutlass::epilogue::thread::SiLu>,
-                    decltype(block_scaling_type)::value>, // Swiglu
-                &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                    GLUAdaptor<cutlass::epilogue::thread::GELU>,
-                    decltype(block_scaling_type)::value>, // Geglu
-                &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
-                    decltype(block_scaling_type)::value>, // SwigluBias
-                &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                    IdentityAdaptor<cutlass::epilogue::thread::Identity>,
-                    decltype(block_scaling_type)::value> // Identity
-            };
-            return fn_list[static_cast<int>(activation_type.activation_type)];
+            if (num_experts_per_node == 128) {
+                constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+                auto fn_list = std::array{
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Gelu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // SwigluBias
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity
+                };
+                return fn_list[static_cast<int>(activation_type.activation_type)];
+            }
+            if (num_experts_per_node == 64) {
+                constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+                auto fn_list = std::array{
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Gelu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // SwigluBias
+                    &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                        IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                        decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity
+                };
+                return fn_list[static_cast<int>(activation_type.activation_type)];
+            }
+            printf("unsupported num_experts_per_node\n");
+            exit(1);
         };
         auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<
             TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
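The 128 and 64 tables above differ only in the constant, which makes them easy targets for copy-paste drift. A self-contained toy of how one table definition can serve every compile-time count (all identifiers below are placeholders, not code from the patch):

```cpp
#include <array>
#include <cstdio>
#include <cstdlib>

// Toy of the pattern: define the kernel-pointer table once, parameterized on
// the compile-time count, so the 128 and 64 branches cannot diverge.
template <int N>
void kernelA() { printf("A<%d>\n", N); }
template <int N>
void kernelB() { printf("B<%d>\n", N); }

using KernelFn = void (*)();

template <int N>
KernelFn pickKernel(int which) {
    auto fn_list = std::array{&kernelA<N>, &kernelB<N>};
    return fn_list[which];
}

KernelFn pickKernelDynamic(int num_experts_per_node, int which) {
    if (num_experts_per_node == 128) return pickKernel<128>(which);
    if (num_experts_per_node == 64) return pickKernel<64>(which);
    printf("unsupported num_experts_per_node\n");
    exit(1);
}

int main() {
    pickKernelDynamic(64, 1)();  // prints "B<64>"
}
```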