Skip to content

Commit 7e6c876

Browse files
committed
Revert "hack: findTotalEltsLessThanTarget_v2 support arbitrary arr len"
This reverts commit 8c719e6.
1 parent b9bb8c7 commit 7e6c876

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -884,18 +884,18 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices
884884

885885
template <class T>
886886
__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
887-
// constexpr int ARR_LENGTH_CONST = 128;
888-
// if (arr_length != ARR_LENGTH_CONST) {
889-
// asm("trap;");
890-
// }
887+
constexpr int ARR_LENGTH_CONST = 128;
888+
if (arr_length != ARR_LENGTH_CONST) {
889+
asm("trap;");
890+
}
891891

892892
constexpr unsigned full_mask = 0xffffffffu;
893893
constexpr int WARP_SZ = 32;
894894
const int lane_id = threadIdx.x & (WARP_SZ - 1);
895895

896896
int local_count = 0;
897897
#pragma unroll
898-
for (int k = 0; k < arr_length / WARP_SZ; ++k) {
898+
for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
899899
const int idx = lane_id + k * WARP_SZ;
900900
T v = sorted_indices[idx];
901901
local_count += (v < target) ? 1 : 0;

0 commit comments

Comments
 (0)