Skip to content

Commit 5873bb3

Browse files
committed
ARR_LENGTH_CONST
1 parent 7b0f471 commit 5873bb3

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -882,9 +882,8 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices
882882
return target_location + 1;
883883
}
884884

885-
template <class T>
885+
template <class T, int ARR_LENGTH_CONST>
886886
__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
887-
constexpr int ARR_LENGTH_CONST = 128;
888887
if (arr_length != ARR_LENGTH_CONST) {
889888
asm("trap;");
890889
}
@@ -910,11 +909,11 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices
910909
return (int64_t)total;
911910
}
912911

913-
template <class T>
912+
template <class T, int ARR_LENGTH_CONST = 128>
914913
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
915914
// return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
916915

917-
return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
916+
return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
918917

919918
// int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
920919
// int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
@@ -1462,7 +1461,7 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
14621461

14631462
template <class InputActivationsType, class ExpandedActivationsType,
14641463
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
1465-
bool PRE_QUANT_AWQ>
1464+
bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
14661465
__global__ void expandInputRowsKernel(
14671466
InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
14681467
float const* unpermuted_scales, float* permuted_scales,
@@ -1557,7 +1556,7 @@ __global__ void expandInputRowsKernel(
15571556

15581557
if constexpr (is_nvfp4 || is_mxfp8) {
15591558
static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
1560-
int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
1559+
int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
15611560
(int64_t)permuted_row + 1) -
15621561
1;
15631562

@@ -2207,7 +2206,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
22072206
// ============================== Activation =================================
22082207

22092208
template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
2210-
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
2209+
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
22112210
__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
22122211
float const* fp8_quant, ScaleBiasType const* bias_ptr,
22132212
bool bias_is_broadcast, int64_t const* expert_first_token_offset,
@@ -2270,7 +2269,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
22702269
activation_params.swiglu_limit) {
22712270
// TODO this is almost certainly faster as a linear scan
22722271
expert =
2273-
findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
2272+
findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
22742273
1;
22752274
gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
22762275
gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;

0 commit comments

Comments
 (0)