
Commit 3a74536

committed
partial cp
1 parent 3704820 commit 3a74536

File tree

1 file changed: +140 -40 lines changed

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 140 additions & 40 deletions
@@ -24,6 +24,7 @@
 #include <numeric>
 #include <random>
 #include <sstream>
+#include <type_traits>

 #include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/common/workspace.h"
@@ -865,7 +866,7 @@ void threeStepBuildExpertMapsSortFirstToken(
 // ============================== Infer GEMM sizes =================================
 // TODO Could linear search be better for small # experts
 template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
+__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
                                                       int64_t const arr_length, T const target) {
   int64_t low = 0, high = arr_length - 1, target_location = -1;
   while (low <= high) {
@@ -881,6 +882,48 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
   return target_location + 1;
 }

+template <int ARR_LENGTH_CONST, class T>
+__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
+  if (arr_length != ARR_LENGTH_CONST) {
+    asm("trap;");
+  }
+
+  constexpr unsigned full_mask = 0xffffffffu;
+  constexpr int WARP_SZ = 32;
+  const int lane_id = threadIdx.x & (WARP_SZ - 1);
+
+  int local_count = 0;
+#pragma unroll
+  for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
+    const int idx = lane_id + k * WARP_SZ;
+    T v = sorted_indices[idx];
+    local_count += (v < target) ? 1 : 0;
+  }
+
+#pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    local_count += __shfl_down_sync(full_mask, local_count, offset);
+  }
+  int total = __shfl_sync(full_mask, local_count, 0);
+
+  return (int64_t)total;
+}
+
+template <int ARR_LENGTH_CONST, class T>
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
+  // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+
+  return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+
+  // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+  // int64_t out_v2 = findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+  // if (out_v1 != out_v2) {
+  //   printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
+  //   asm("trap;");
+  // }
+  // return out_v1;
+}
+
 template <class T>
 using sizeof_bits = cutlass::sizeof_bits<
     typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
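
The v2 routine above replaces the per-thread binary search with a warp-cooperative linear count over a compile-time-sized array: each of the 32 lanes counts its strided share of the entries, a __shfl_down_sync tree folds the lane counts into lane 0, and __shfl_sync broadcasts the total. For a sorted array, the count of elements less than the target equals the binary-search result, which is why v1 and v2 are interchangeable. Below is a minimal standalone sketch of the same pattern, assuming a 128-entry sorted array and a single full warp; warpCountLessThan and the test harness are illustrative, not part of the commit.

#include <cstdint>
#include <cstdio>

// Each lane counts its strided slice; a shuffle tree reduces the 32 counts.
__global__ void warpCountLessThan(int64_t const* data, int64_t target, int64_t* out) {
  constexpr int kArrLen = 128;  // compile-time length, as the commit assumes
  constexpr int kWarpSz = 32;
  int const lane = threadIdx.x & (kWarpSz - 1);

  int count = 0;
#pragma unroll
  for (int k = 0; k < kArrLen / kWarpSz; ++k) {
    count += (data[lane + k * kWarpSz] < target) ? 1 : 0;  // 4 loads per lane
  }
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    count += __shfl_down_sync(0xffffffffu, count, offset);  // fold into lane 0
  }
  if (lane == 0) *out = count;  // the commit instead broadcasts to all lanes
}

int main() {
  int64_t h_data[128];
  for (int i = 0; i < 128; ++i) h_data[i] = 2 * i;  // sorted: 0, 2, ..., 254
  int64_t *d_data, *d_out, h_out = -1;
  cudaMalloc(&d_data, sizeof(h_data));
  cudaMalloc(&d_out, sizeof(int64_t));
  cudaMemcpy(d_data, h_data, sizeof(h_data), cudaMemcpyHostToDevice);
  warpCountLessThan<<<1, 32>>>(d_data, /*target=*/101, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(int64_t), cudaMemcpyDeviceToHost);
  printf("elements < 101: %lld (expect 51)\n", (long long)h_out);
  cudaFree(d_data);
  cudaFree(d_out);
  return 0;
}

The trade is O(log n) divergent binary-search steps against n/32 uniform loads plus a fixed 5-step shuffle reduction, which favors v2 for small, warp-sized expert counts.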
@@ -1418,16 +1461,19 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 256;

 template <class InputActivationsType, class ExpandedActivationsType,
           TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
-          bool PRE_QUANT_AWQ>
+          bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
-    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size,
+    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_,
     int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
     int64_t const* expert_first_token_offset,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf,
     int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) {
+  constexpr int hidden_size = 7168;
+  if (hidden_size != hidden_size_real_) { asm("trap;"); }
+
   static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ||
                     !PRE_QUANT_AWQ,
                 "AWQ and Block Scaling are mutually exclusive");
@@ -1503,14 +1549,14 @@ __global__ void expandInputRowsKernel(
       permuted_row * hidden_size / ELEM_PER_THREAD;

   int64_t const start_offset = threadIdx.x;
-  int64_t const stride = EXPAND_THREADS_PER_BLOCK;
-  int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD;
+  constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK;
+  constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD;
   assert(hidden_size % ELEM_PER_THREAD == 0);
   assert(hidden_size % VecSize == 0);

   if constexpr (is_nvfp4 || is_mxfp8) {
     static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
-    int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+    int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
                                                  (int64_t)permuted_row + 1) -
                      1;

@@ -1519,6 +1565,7 @@ __global__ void expandInputRowsKernel(
     float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
     int64_t num_tokens_before_expert = expert_first_token_offset[expert];

+#pragma unroll
     for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
       auto in_vec = source_row_ptr[elem_index];
       if constexpr (need_nvfp4_quant || need_mxfp8_quant) {
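
The #pragma unroll above only pays off because the loop bound and stride are now compile-time constants. A small sketch of that loop shape under the commit's assumed constants (hidden_size 7168, 8 elements per thread, 256 threads per block, so 896 vector slots and 3 or 4 iterations per thread); the kernel name and body are stand-ins:

#include <cstdio>

// Trip bound and stride are constexpr, so the compiler can fully unroll.
__global__ void unrolledColumnLoop(float* dst, float const* src) {
  constexpr int kNumElemsInCol = 7168 / 8;  // 896 vector slots per row
  constexpr int kStride = 256;              // EXPAND_THREADS_PER_BLOCK analog
#pragma unroll
  for (int i = threadIdx.x; i < kNumElemsInCol; i += kStride) {
    dst[i] = src[i] * 2.0f;  // stand-in for the dequant/quant vector body
  }
}

int main() {
  float *d_src, *d_dst;
  cudaMalloc(&d_src, 896 * sizeof(float));
  cudaMalloc(&d_dst, 896 * sizeof(float));
  unrolledColumnLoop<<<1, 256>>>(d_dst, d_src);
  cudaDeviceSynchronize();
  printf("launch %s\n", cudaGetLastError() == cudaSuccess ? "ok" : "failed");
  cudaFree(d_src);
  cudaFree(d_dst);
  return 0;
}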
@@ -1687,9 +1734,20 @@ void expandInputRowsKernelLauncher(
     TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale,
                          "NVFP4 block scaling is expected for FP4xFP4");
     TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ");
-    return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+    if (num_experts_per_node == 128) {
+      constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+      return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
                                     TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
-                                  false>;
+                                    false, NUM_EXPERTS_PER_NODE_CONST>;
+    }
+    if (num_experts_per_node == 64) {
+      constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+      return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+                                    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
+                                    false, NUM_EXPERTS_PER_NODE_CONST>;
+    }
+    printf("unsupported num_experts_per_node\n");
+    exit(1);
   } else
 #endif
   {
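
Because a kernel pointer can only be formed from a compile-time expert count, the launcher bridges the runtime num_experts_per_node to the template argument with an explicit if-chain over the supported values (128 and 64), failing loudly otherwise. A generic sketch of this runtime-to-compile-time dispatch idiom; dispatchNumExperts and launchSpecialized are hypothetical names, not from the commit:

#include <cstdio>
#include <cstdlib>
#include <type_traits>

// Map a runtime value onto a compile-time constant and hand it to a callable;
// unsupported values abort, matching the commit's behaviour.
template <class Fn>
void dispatchNumExperts(int num_experts, Fn&& fn) {
  switch (num_experts) {
    case 64:  fn(std::integral_constant<int, 64>{});  break;
    case 128: fn(std::integral_constant<int, 128>{}); break;
    default:
      printf("unsupported num_experts_per_node\n");
      exit(1);
  }
}

template <int kNumExperts>
void launchSpecialized() {  // stand-in for picking a kernel instantiation
  printf("selected kernel specialized for %d experts\n", kNumExperts);
}

int main() {
  dispatchNumExperts(128, [](auto n) { launchSpecialized<decltype(n)::value>(); });
  dispatchNumExperts(64, [](auto n) { launchSpecialized<decltype(n)::value>(); });
  return 0;
}

Folding the branches into one dispatcher keeps the supported set in a single place; adding, say, 256 experts then means one extra case instead of another copy of the launch block.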
@@ -1748,11 +1806,16 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
 // connection.
 template <typename OutputType, class GemmOutputType, class ScaleBiasType, ScaleMode SCALE_MODE>
-__global__ void finalizeMoeRoutingKernel(
+__global__
+__maxnreg__(64)
+void finalizeMoeRoutingKernel(
     GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
     ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
-    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token,
+    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
     int const num_experts_per_node, int const start_expert_id) {
+  constexpr int experts_per_token = 8;
+  if (experts_per_token != experts_per_token_real_) { asm("trap;"); }
+
   int64_t const original_row = blockIdx.x;
   int64_t const num_rows = gridDim.x;
   auto const offset = original_row * orig_cols;
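
__maxnreg__(64) caps the registers the compiler may allocate per thread for this kernel (the attribute requires a recent CUDA toolkit); fewer registers per thread lets more warps stay resident on an SM, at the risk of spills to local memory if 64 proves too tight. A minimal sketch of the attribute's placement, mirroring the commit's style; cappedKernel is illustrative:

#include <cstdio>

// Register-capped kernel: the compiler must fit each thread into 64 registers.
__global__
__maxnreg__(64)
void cappedKernel(float* out) {
  out[threadIdx.x] = 2.0f * threadIdx.x;
}

int main() {
  float* d_out;
  cudaMalloc(&d_out, 256 * sizeof(float));
  cappedKernel<<<1, 256>>>(d_out);
  cudaDeviceSynchronize();
  printf("launch %s\n", cudaGetLastError() == cudaSuccess ? "ok" : "failed");
  cudaFree(d_out);
  return 0;
}

__launch_bounds__ and the -maxrregcount compiler flag are the older knobs for the same trade-off; __maxnreg__ scopes the cap to a single kernel.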
@@ -2078,7 +2141,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output,
   float gate_bias = 0.0f;
   float gate_limit = std::numeric_limits<float>::infinity();
   if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) {
-    int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+    int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node,
                                              (int64_t)token + 1) -
                  1;
     gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f;
@@ -2126,14 +2189,17 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
 // ============================== Activation =================================

 template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
-          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
+          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                                    float const* fp8_quant, ScaleBiasType const* bias_ptr,
                                    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
-                                   int num_experts_per_node, int64_t inter_size,
+                                   int num_experts_per_node, int64_t inter_size_real_,
                                    float const* fc2_act_global_scale, bool use_per_expert_act_scale,
                                    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat,
                                    ActivationParams activation_params) {
+  constexpr int inter_size = 2048;
+  if (inter_size != inter_size_real_) { asm("trap;"); }
+
 #ifdef ENABLE_FP4
   constexpr bool IsNVFP4 =
       std::is_same_v<T, __nv_fp4_e2m1> &&
@@ -2186,7 +2252,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                          activation_params.swiglu_limit) {
     // TODO this is almost certainly faster as a linear scan
     expert =
-        findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
+        findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
         1;
     gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
     gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
@@ -2218,16 +2284,18 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
   auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
   auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
   int64_t const start_offset = tid;
-  int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
+  constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK;
   assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
-  int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
+  constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
   assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
   int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;

   ActFn fn{};
   fn.alpha = gate_alpha;
   fn.beta = gate_beta;
   fn.limit = gate_limit;
+
+#pragma unroll
   for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
     auto fc1_value =
         arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);
@@ -2358,30 +2426,62 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8

   auto fn = [&]() {
     auto fn = [&](auto block_scaling_type) {
-      auto fn_list = std::array{
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::GELU>,
-                              decltype(block_scaling_type)::value>,  // Gelu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
-                              decltype(block_scaling_type)::value>,  // Relu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
-                              decltype(block_scaling_type)::value>,  // Silu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              GLUAdaptor<cutlass::epilogue::thread::SiLu>,
-                              decltype(block_scaling_type)::value>,  // Swiglu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              GLUAdaptor<cutlass::epilogue::thread::GELU>,
-                              decltype(block_scaling_type)::value>,  // Geglu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
-                              decltype(block_scaling_type)::value>,  // SwigluBias
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::Identity>,
-                              decltype(block_scaling_type)::value>  // Identity
-
-      };
-      return fn_list[static_cast<int>(activation_type.activation_type)];
+      if (num_experts_per_node == 128) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      if (num_experts_per_node == 64) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      printf("unsupported num_experts_per_node\n");
+      exit(1);
     };
     auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<
         TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
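
The two fn_list blocks above differ only in the expert-count constant. A hedged refactor sketch: generate the table once from a helper templated on the count. Everything below is a standalone stand-in (doActivationKernel and the adaptors are stubbed; makeFnList is a hypothetical name):

#include <array>
#include <cstdio>

enum class Act { Gelu, Relu, Silu, Count };

// Stand-in for &doActivationKernel<..., kNumExperts>.
template <int kNumExperts, Act kAct>
void activationKernelStub() {
  printf("act=%d, experts=%d\n", static_cast<int>(kAct), kNumExperts);
}

// Build the activation table once per expert count instead of copy-pasting it.
template <int kNumExperts>
constexpr auto makeFnList() {
  return std::array{&activationKernelStub<kNumExperts, Act::Gelu>,
                    &activationKernelStub<kNumExperts, Act::Relu>,
                    &activationKernelStub<kNumExperts, Act::Silu>};
}

int main() {
  auto fn128 = makeFnList<128>()[static_cast<int>(Act::Silu)];
  auto fn64 = makeFnList<64>()[static_cast<int>(Act::Relu)];
  fn128();
  fn64();
  return 0;
}

With such a helper, the 128/64 branches in the commit would each shrink to one makeFnList instantiation plus the table lookup.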
