
Commit bb7a97a

Revert "ARR_LENGTH_CONST"
This reverts commit 5873bb3.
1 parent 5873bb3 commit bb7a97a
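
The change being reverted had promoted the hard-coded 128 to a non-type template parameter; this revert moves it back into the function bodies as a local constexpr. As a minimal sketch of the two shapes, using a hypothetical countLessThan helper rather than the file's actual code:

#include <cstdint>

// Shape being reverted: the length is a non-type template parameter, so call
// sites spell it out and the compiler sees the bound at instantiation time.
template <class T, int N>
__device__ inline int64_t countLessThan(T const* sorted, T target) {
  int64_t total = 0;
#pragma unroll
  for (int i = 0; i < N; ++i) {
    total += (sorted[i] < target) ? 1 : 0;  // bound N is compile-time
  }
  return total;
}

// Shape this commit restores: the constant lives inside the body and the
// runtime length argument is only validated against it.
template <class T>
__device__ inline int64_t countLessThan(T const* sorted, int64_t arr_length, T target) {
  constexpr int N = 128;  // mirrors ARR_LENGTH_CONST in the diff below
  if (arr_length != N) {
    asm("trap;");         // same guard as findTotalEltsLessThanTarget_v2
  }
  int64_t total = 0;
#pragma unroll
  for (int i = 0; i < N; ++i) {
    total += (sorted[i] < target) ? 1 : 0;
  }
  return total;
}

The loop bound is a compile-time constant either way; the difference is only whether every caller has to thread it through a template argument.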

File tree

1 file changed: +8 -7 lines changed
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 8 additions & 7 deletions
@@ -882,8 +882,9 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices
   return target_location + 1;
 }
 
-template <class T, int ARR_LENGTH_CONST>
+template <class T>
 __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
+  constexpr int ARR_LENGTH_CONST = 128;
   if (arr_length != ARR_LENGTH_CONST) {
     asm("trap;");
   }
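
The restored guard relies on asm("trap;"), which aborts the executing kernel; the failure then surfaces on the host at the next synchronizing CUDA call. A self-contained illustration of that behavior with a hypothetical kernel:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Mirrors the ARR_LENGTH_CONST guard above: any length other than 128 traps.
__global__ void trapIfWrongLength(int64_t arr_length) {
  if (arr_length != 128) {
    asm("trap;");  // aborts the kernel; the host observes a launch error
  }
}

int main() {
  trapIfWrongLength<<<1, 1>>>(64);  // deliberately wrong length
  cudaError_t err = cudaDeviceSynchronize();
  std::printf("status: %s\n", cudaGetErrorString(err));  // an error, not cudaSuccess
  return 0;
}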
@@ -909,11 +910,11 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices
   return (int64_t)total;
 }
 
-template <class T, int ARR_LENGTH_CONST = 128>
+template <class T>
 __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
   // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
 
-  return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+  return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
 
   // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
   // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
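
A plausible motivation for pinning the length at compile time, stated as an assumption since the body of findTotalEltsLessThanTarget_v2 is mostly outside this diff: a search over a known power-of-two extent unrolls into exactly seven comparisons (2^7 = 128) plus one fix-up, with no loop bookkeeping. A sketch of such a search under that assumption, with the hypothetical name countLessThan128:

#include <cstdint>

// Unrolled binary search over exactly 128 sorted entries; returns how many
// entries are strictly less than target, matching the function's name.
template <class T>
__device__ inline int64_t countLessThan128(T const* sorted, T target) {
  int pos = 0;
#pragma unroll
  for (int step = 64; step > 0; step >>= 1) {
    if (sorted[pos + step - 1] < target) {
      pos += step;  // everything through index pos + step - 1 is < target
    }
  }
  // Steps 64..1 can only reach index 127; one fix-up allows a count of 128.
  pos += (sorted[pos] < target) ? 1 : 0;
  return pos;
}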
@@ -1461,7 +1462,7 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
 
 template <class InputActivationsType, class ExpandedActivationsType,
           TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
-          bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
+          bool PRE_QUANT_AWQ>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
@@ -1556,7 +1557,7 @@ __global__ void expandInputRowsKernel(
 
     if constexpr (is_nvfp4 || is_mxfp8) {
       static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
-      int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
+      int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
                                                    (int64_t)permuted_row + 1) -
                        1;
 
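
Both kernels use the same lookup idiom: expert_first_token_offset is a prefix array where expert e owns rows [offset[e], offset[e + 1]), so counting offsets less than row + 1 and subtracting one recovers e. A small host-side illustration of that arithmetic with made-up offsets:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical offsets for 4 experts: expert e owns rows
  // [offsets[e], offsets[e + 1]); e.g. expert 2 owns rows 5..8.
  int64_t const offsets[5] = {0, 2, 5, 9, 12};
  int64_t const row = 6;

  // Same arithmetic as the kernels: count offsets < row + 1, then subtract 1.
  int64_t expert =
      (std::upper_bound(offsets, offsets + 4, row) - offsets) - 1;
  std::printf("row %lld belongs to expert %lld\n",
              (long long)row, (long long)expert);  // prints expert 2
  return 0;
}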
@@ -2206,7 +2207,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
 // ============================== Activation =================================
 
 template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
-          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
+          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                                    float const* fp8_quant, ScaleBiasType const* bias_ptr,
                                    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
@@ -2269,7 +2270,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
               activation_params.swiglu_limit) {
       // TODO this is almost certainly faster as a linear scan
       expert =
-          findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
+          findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
           1;
       gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
       gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
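
The TODO kept in this hunk ("almost certainly faster as a linear scan") refers to the same lookup; with a small expert count, a predicated linear pass avoids the data-dependent access pattern of a binary search. A hedged sketch of that alternative, not part of the commit:

#include <cstdint>

// Hypothetical linear-scan variant of the expert lookup suggested by the TODO.
__device__ inline int64_t expertOfTokenLinear(int64_t const* expert_first_token_offset,
                                              int64_t num_experts_per_node,
                                              int64_t token) {
  int64_t expert = 0;
  for (int64_t e = 1; e < num_experts_per_node; ++e) {
    if (expert_first_token_offset[e] <= token) {
      expert = e;  // advance while this expert's range starts at or before token
    }
  }
  return expert;  // last expert whose first token is at or before token
}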
