
Commit b204a15

hack NUM_EXPERTS_PER_NODE_CONST
1 parent 7e6c876 commit b204a15

1 file changed: 77 additions, 35 deletions

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 77 additions & 35 deletions
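
The commit replaces the hard-coded constexpr int ARR_LENGTH_CONST = 128 inside findTotalEltsLessThanTarget_v2 with a template parameter, threads a matching NUM_EXPERTS_PER_NODE_CONST parameter through expandInputRowsKernel and doActivationKernel, and makes the host-side launchers branch on the runtime num_experts_per_node (128 or 64) to pick the corresponding instantiation. The sketch below illustrates that runtime-to-compile-time dispatch pattern in isolation; it is not code from the file, the names countLessThan, markerKernel, and pickKernel are hypothetical, and the linear scan is only a stand-in for the real search in findTotalEltsLessThanTarget_v2, whose body is not shown in this diff.

// Illustrative sketch only (not from cutlass_fused_moe_kernels.cuh): a runtime
// expert count is mapped to a compile-time template parameter so the search
// loop bound inside the kernel is a constant.
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Device-side search with the array length fixed at compile time. The guard
// mirrors the commit: trap if the runtime length does not match the constant.
template <int ARR_LENGTH_CONST, class T>
__device__ inline int64_t countLessThan(T const* sorted, int64_t arr_length, T target) {
  if (arr_length != ARR_LENGTH_CONST) {
    asm("trap;");
  }
  int count = 0;
#pragma unroll
  for (int i = 0; i < ARR_LENGTH_CONST; ++i) {
    count += (sorted[i] < target) ? 1 : 0;  // stand-in for the real search in _v2
  }
  return count;
}

// A toy kernel that resolves which expert owns a row, the same way the
// expand/activation kernels call findTotalEltsLessThanTarget<...>(...) - 1.
template <int NUM_EXPERTS_PER_NODE_CONST>
__global__ void markerKernel(int64_t const* expert_first_token_offset, int64_t row, int64_t* expert_out) {
  *expert_out = countLessThan<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset,
                                                          NUM_EXPERTS_PER_NODE_CONST, row + 1) - 1;
}

// Host-side dispatch: branch on the runtime value and hand back the matching
// instantiation, the same shape used in expandInputRowsKernelLauncher / doActivation.
using KernelFn = void (*)(int64_t const*, int64_t, int64_t*);

inline KernelFn pickKernel(int num_experts_per_node) {
  if (num_experts_per_node == 128) {
    return &markerKernel<128>;
  }
  if (num_experts_per_node == 64) {
    return &markerKernel<64>;
  }
  std::printf("unsupported num_experts_per_node\n");
  std::exit(1);
}
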
@@ -882,9 +882,8 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices
   return target_location + 1;
 }
 
-template <class T>
+template <int ARR_LENGTH_CONST, class T>
 __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
-  constexpr int ARR_LENGTH_CONST = 128;
   if (arr_length != ARR_LENGTH_CONST) {
     asm("trap;");
   }
@@ -910,11 +909,11 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices
   return (int64_t)total;
 }
 
-template <class T>
+template <int ARR_LENGTH_CONST, class T>
 __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
   // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
 
-  return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
+  return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
 
   // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
   // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
@@ -1462,7 +1461,7 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
 
 template <class InputActivationsType, class ExpandedActivationsType,
           TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
-          bool PRE_QUANT_AWQ>
+          bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
@@ -1557,7 +1556,7 @@ __global__ void expandInputRowsKernel(
 
     if constexpr (is_nvfp4 || is_mxfp8) {
       static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
-      int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+      int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
                                                    (int64_t)permuted_row + 1) -
                        1;
 
@@ -1735,9 +1734,20 @@ void expandInputRowsKernelLauncher(
     TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale,
                          "NVFP4 block scaling is expected for FP4xFP4");
     TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ");
-    return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
-                                  TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
-                                  false>;
+    if (num_experts_per_node == 128) {
+      constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+      return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+                                    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
+                                    false, NUM_EXPERTS_PER_NODE_CONST>;
+    }
+    if (num_experts_per_node == 64) {
+      constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+      return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+                                    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
+                                    false, NUM_EXPERTS_PER_NODE_CONST>;
+    }
+    printf("unsupported num_experts_per_node\n");
+    exit(1);
   } else
 #endif
   {
@@ -2159,7 +2169,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output,
     float gate_bias = 0.0f;
     float gate_limit = std::numeric_limits<float>::infinity();
     if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) {
-      int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+      int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node,
                                                (int64_t)token + 1) -
                    1;
       gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f;
@@ -2207,7 +2217,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
 // ============================== Activation =================================
 
 template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
-          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
+          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                                    float const* fp8_quant, ScaleBiasType const* bias_ptr,
                                    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
@@ -2270,7 +2280,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
         activation_params.swiglu_limit) {
       // TODO this is almost certainly faster as a linear scan
       expert =
-          findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
+          findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
           1;
       gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
       gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
@@ -2444,30 +2454,62 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
 
   auto fn = [&]() {
     auto fn = [&](auto block_scaling_type) {
-      auto fn_list = std::array{
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::GELU>,
-                              decltype(block_scaling_type)::value>,  // Gelu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
-                              decltype(block_scaling_type)::value>,  // Relu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
-                              decltype(block_scaling_type)::value>,  // Silu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              GLUAdaptor<cutlass::epilogue::thread::SiLu>,
-                              decltype(block_scaling_type)::value>,  // Swiglu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              GLUAdaptor<cutlass::epilogue::thread::GELU>,
-                              decltype(block_scaling_type)::value>,  // Geglu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
-                              decltype(block_scaling_type)::value>,  // SwigluBias
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::Identity>,
-                              decltype(block_scaling_type)::value>  // Identity
-
-      };
-      return fn_list[static_cast<int>(activation_type.activation_type)];
+      if (num_experts_per_node == 128) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      if (num_experts_per_node == 64) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      printf("unsupported num_experts_per_node\n");
+      exit(1);
     };
     auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<
         TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
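
For completeness, a hedged usage sketch of the dispatch helper from the earlier illustration: the launcher-style function returns a plain kernel pointer, and the caller launches whichever instantiation was selected, as the real launchers do with expandInputRowsKernel. pickKernel, markerKernel, KernelFn, and runMarker are the hypothetical names from that sketch, not symbols from this file.

// Usage sketch (illustrative): select the instantiation once, then launch it
// through the returned pointer; the expert-count constant is baked into the kernel.
#include <cstdint>
#include <cuda_runtime.h>

void runMarker(int num_experts_per_node, int64_t const* d_offsets, int64_t row,
               int64_t* d_expert_out, cudaStream_t stream) {
  KernelFn kernel = pickKernel(num_experts_per_node);        // host-side branch on the runtime value
  kernel<<<1, 1, 0, stream>>>(d_offsets, row, d_expert_out); // launch the chosen instantiation
}
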
