@@ -882,8 +882,9 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices
   return target_location + 1;
 }
 
-template <class T, int ARR_LENGTH_CONST>
+template <class T>
 __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
+  constexpr int ARR_LENGTH_CONST = 128;
   if (arr_length != ARR_LENGTH_CONST) {
     asm("trap;");
   }
@@ -909,11 +910,11 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices
   return (int64_t)total;
 }
 
-template <class T, int ARR_LENGTH_CONST = 128>
+template <class T>
 __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
   // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
 
-  return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+  return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
 
   // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
   // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
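
The body of findTotalEltsLessThanTarget_v2 between the trap guard and the final return (int64_t)total; falls outside these hunks. As a rough sketch only (the name findTotalEltsLessThanTarget_v2_sketch and the unrolled counting loop are assumptions, not the committed code), a fixed-length, count-based search consistent with the visible pieces could look like this:

// Hypothetical sketch, not the committed body: a fully unrolled, count-based
// search over a compile-time fixed array length.
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget_v2_sketch(
    T const* sorted_indices, int64_t const arr_length, T const target) {
  constexpr int ARR_LENGTH_CONST = 128;
  if (arr_length != ARR_LENGTH_CONST) {
    asm("trap;");  // caller must guarantee the fixed length, as in the real v2
  }
  int total = 0;
#pragma unroll
  for (int i = 0; i < ARR_LENGTH_CONST; ++i) {
    // Counting entries below the target gives the same answer on a sorted array
    // as the binary search in _v1 (the number of elements less than target),
    // without data-dependent branching.
    total += (sorted_indices[i] < target) ? 1 : 0;
  }
  return (int64_t)total;
}

Hardcoding ARR_LENGTH_CONST inside the function is what lets the template parameter, and the <NUM_EXPERTS_PER_NODE_CONST> template arguments at the call sites changed below, be dropped.
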
@@ -1461,7 +1462,7 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
 
 template <class InputActivationsType, class ExpandedActivationsType,
           TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
-          bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
+          bool PRE_QUANT_AWQ>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
@@ -1556,7 +1557,7 @@ __global__ void expandInputRowsKernel(
 
   if constexpr (is_nvfp4 || is_mxfp8) {
     static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
-    int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
+    int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
                          (int64_t)permuted_row + 1) -
                      1;
 
@@ -2206,7 +2207,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
 // ============================== Activation =================================
 
 template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
-          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
+          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
     float const* fp8_quant, ScaleBiasType const* bias_ptr,
     bool bias_is_broadcast, int64_t const* expert_first_token_offset,
@@ -2269,7 +2270,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
           activation_params.swiglu_limit) {
       // TODO this is almost certainly faster as a linear scan
       expert =
-          findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
+          findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
           1;
       gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
       gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;