#include <numeric>
#include <random>
#include <sstream>
+ #include <type_traits>

#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/workspace.h"
@@ -865,7 +866,7 @@ void threeStepBuildExpertMapsSortFirstToken(
// ============================== Infer GEMM sizes =================================
// TODO Could linear search be better for small # experts
template <class T>
- __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
+ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
    int64_t const arr_length, T const target) {
    int64_t low = 0, high = arr_length - 1, target_location = -1;
    while (low <= high) {
@@ -881,6 +882,48 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
    return target_location + 1;
}

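+ // Warp-parallel replacement for the binary search above: each lane counts the elements
+ // of its strided slice that are < target, then a shuffle reduction sums the per-lane
+ // counts. Requires ARR_LENGTH_CONST to be a multiple of 32 and all 32 lanes of the warp
+ // to call this convergently with identical arguments.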
+ template <int ARR_LENGTH_CONST, class T>
+ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
+     if (arr_length != ARR_LENGTH_CONST) {
+         asm("trap;");
+     }
+
+     constexpr unsigned full_mask = 0xffffffffu;
+     constexpr int WARP_SZ = 32;
+     const int lane_id = threadIdx.x & (WARP_SZ - 1);
+
+     int local_count = 0;
+ #pragma unroll
+     for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
+         const int idx = lane_id + k * WARP_SZ;
+         T v = sorted_indices[idx];
+         local_count += (v < target) ? 1 : 0;
+     }
+
+ #pragma unroll
+     for (int offset = 16; offset > 0; offset >>= 1) {
+         local_count += __shfl_down_sync(full_mask, local_count, offset);
+     }
+     int total = __shfl_sync(full_mask, local_count, 0);
+
+     return (int64_t)total;
+ }
+
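+ // Dispatcher: routes callers to the warp-parallel v2. The original binary search is kept
+ // as _v1, and the commented-out block below cross-checks v1 against v2 and traps on any
+ // mismatch, which is useful when re-validating this optimization.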
+ template <int ARR_LENGTH_CONST, class T>
+ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
+     // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+
+     return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+
+     // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+     // int64_t out_v2 = findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+     // if (out_v1 != out_v2) {
+     //     printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
+     //     asm("trap;");
+     // }
+     // return out_v1;
+ }
+
template <class T>
using sizeof_bits = cutlass::sizeof_bits<
    typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
@@ -1418,16 +1461,19 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 256;

template <class InputActivationsType, class ExpandedActivationsType,
    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
-     bool PRE_QUANT_AWQ>
+     bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
__global__ void expandInputRowsKernel(
    InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
    float const* unpermuted_scales, float* permuted_scales,
-     int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size,
+     int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_,
    int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
    int64_t const* expert_first_token_offset,
    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf,
    int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) {
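+     // Assumed shortcut: hidden_size is fixed at 7168 for the deployment this patch targets,
+     // so the per-thread stride and column counts below become compile-time constants and the
+     // copy loop can be unrolled. The trap guards against any other hidden size.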
+     constexpr int hidden_size = 7168;
+     if (hidden_size != hidden_size_real_) { asm("trap;"); }
+
    static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ||
            !PRE_QUANT_AWQ,
        "AWQ and Block Scaling are mutually exclusive");
@@ -1503,14 +1549,14 @@ __global__ void expandInputRowsKernel(
        permuted_row * hidden_size / ELEM_PER_THREAD;

    int64_t const start_offset = threadIdx.x;
-     int64_t const stride = EXPAND_THREADS_PER_BLOCK;
-     int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD;
+     constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK;
+     constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD;
    assert(hidden_size % ELEM_PER_THREAD == 0);
    assert(hidden_size % VecSize == 0);

    if constexpr (is_nvfp4 || is_mxfp8) {
        static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
-         int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+         int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
                (int64_t)permuted_row + 1) -
            1;

@@ -1519,6 +1565,7 @@ __global__ void expandInputRowsKernel(
        float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
        int64_t num_tokens_before_expert = expert_first_token_offset[expert];

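+         // With hidden_size, stride, and num_elems_in_col all compile-time constants,
+         // this copy/quantize loop has a fixed trip count and can be unrolled.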
+ #pragma unroll
        for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
            auto in_vec = source_row_ptr[elem_index];
            if constexpr (need_nvfp4_quant || need_mxfp8_quant) {
@@ -1687,9 +1734,20 @@ void expandInputRowsKernelLauncher(
        TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale,
            "NVFP4 block scaling is expected for FP4xFP4");
        TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ");
-         return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
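+         // Dispatch on the runtime expert count so the kernel receives it as a template
+         // argument; only 128 and 64 experts per node are instantiated in this patch.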
+         if (num_experts_per_node == 128) {
+             constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+             return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
                TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
-             false>;
+                 false, NUM_EXPERTS_PER_NODE_CONST>;
+         }
+         if (num_experts_per_node == 64) {
+             constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+             return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+                 TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
+                 false, NUM_EXPERTS_PER_NODE_CONST>;
+         }
+         printf("unsupported num_experts_per_node\n");
+         exit(1);
    } else
#endif
    {
@@ -1748,11 +1806,16 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip
// connection.
template <typename OutputType, class GemmOutputType, class ScaleBiasType, ScaleMode SCALE_MODE>
- __global__ void finalizeMoeRoutingKernel(
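+ // __maxnreg__(64) caps this kernel at 64 registers per thread, trading possible spills
+ // for higher occupancy; the unpermute/reduce loop is assumed to be memory-bound here.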
+ __global__
+ __maxnreg__(64)
+ void finalizeMoeRoutingKernel(
    GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
-     int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token,
+     int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
    int const num_experts_per_node, int const start_expert_id) {
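+     // Assumed shortcut: top-k is fixed at 8 experts per token for this deployment, so code
+     // indexed by experts_per_token sees a compile-time constant; trap if the runtime top-k differs.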
+     constexpr int experts_per_token = 8;
+     if (experts_per_token != experts_per_token_real_) { asm("trap;"); }
+
    int64_t const original_row = blockIdx.x;
    int64_t const num_rows = gridDim.x;
    auto const offset = original_row * orig_cols;
@@ -2078,7 +2141,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output,
    float gate_bias = 0.0f;
    float gate_limit = std::numeric_limits<float>::infinity();
    if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) {
-         int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
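+         // Note: the expert count is hard-coded to 128 here rather than templated; with 64
+         // experts per node this call would hit the arr_length trap in findTotalEltsLessThanTarget_v2.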
+         int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node,
                (int64_t)token + 1) -
            1;
        gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f;
@@ -2126,14 +2189,17 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
// ============================== Activation =================================

template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
-     TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
+     TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
    float const* fp8_quant, ScaleBiasType const* bias_ptr,
    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
-     int num_experts_per_node, int64_t inter_size,
+     int num_experts_per_node, int64_t inter_size_real_,
    float const* fc2_act_global_scale, bool use_per_expert_act_scale,
    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat,
    ActivationParams activation_params) {
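+     // Assumed shortcut: the per-expert intermediate size is fixed at 2048 for this deployment,
+     // so the activation loop bounds below are compile-time constants; trap if the runtime value differs.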
+     constexpr int inter_size = 2048;
+     if (inter_size != inter_size_real_) { asm("trap;"); }
+
#ifdef ENABLE_FP4
    constexpr bool IsNVFP4 =
        std::is_same_v<T, __nv_fp4_e2m1> &&
@@ -2186,7 +2252,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
            activation_params.swiglu_limit) {
        // TODO this is almost certainly faster as a linear scan
        expert =
-             findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
+             findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
            1;
        gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
        gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
@@ -2218,16 +2284,18 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
    auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
    auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
    int64_t const start_offset = tid;
-     int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
+     constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK;
    assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
-     int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
+     constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
    assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
    int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;

    ActFn fn{};
    fn.alpha = gate_alpha;
    fn.beta = gate_beta;
    fn.limit = gate_limit;
+
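+     // Same idea as in expandInputRowsKernel: with inter_size and stride known at compile
+     // time, the activation loop has a fixed trip count and can be unrolled.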
+ #pragma unroll
    for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
        auto fc1_value =
            arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);
@@ -2358,30 +2426,62 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8

    auto fn = [&]() {
        auto fn = [&](auto block_scaling_type) {
-             auto fn_list = std::array{
-                 &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                     IdentityAdaptor<cutlass::epilogue::thread::GELU>,
-                     decltype(block_scaling_type)::value>, // Gelu
-                 &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                     IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
-                     decltype(block_scaling_type)::value>, // Relu
-                 &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                     IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
-                     decltype(block_scaling_type)::value>, // Silu
-                 &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                     GLUAdaptor<cutlass::epilogue::thread::SiLu>,
-                     decltype(block_scaling_type)::value>, // Swiglu
-                 &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                     GLUAdaptor<cutlass::epilogue::thread::GELU>,
-                     decltype(block_scaling_type)::value>, // Geglu
-                 &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
-                     decltype(block_scaling_type)::value>, // SwigluBias
-                 &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                     IdentityAdaptor<cutlass::epilogue::thread::Identity>,
-                     decltype(block_scaling_type)::value> // Identity
-
-             };
-             return fn_list[static_cast<int>(activation_type.activation_type)];
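+             // The activation-function table is rebuilt per expert-count branch so that
+             // NUM_EXPERTS_PER_NODE_CONST flows into every kernel instantiation; only 128
+             // and 64 experts per node are supported, anything else aborts.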
+             if (num_experts_per_node == 128) {
+                 constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+                 auto fn_list = std::array{
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Gelu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // SwigluBias
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity
+
+                 };
+                 return fn_list[static_cast<int>(activation_type.activation_type)];
+             }
+             if (num_experts_per_node == 64) {
+                 constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+                 auto fn_list = std::array{
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Gelu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // SwigluBias
+                     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                         IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                         decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity
+
+                 };
+                 return fn_list[static_cast<int>(activation_type.activation_type)];
+             }
+             printf("unsupported num_experts_per_node\n");
+             exit(1);
        };
        auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<
            TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,