Skip to content

[Optimize]Add norm_weights feature for topk_gating_softmax #3372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions custom_ops/gpu_ops/moe/fused_moe_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ template <typename T,
int NUM_EXPERTS,
int WARPS_PER_CTA,
int BYTES_PER_LDG,
bool Norm_Weights = false,
typename IdxT = int>
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
void topk_gating_softmax(const T* input,
Expand Down Expand Up @@ -628,6 +629,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
const int thread_row_in_cta = thread_row - cta_base_row;

// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows) return;
Expand All @@ -643,6 +645,9 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

T weight_sum = static_cast<T>(0);
extern __shared__ T row_output[];

// Determine the pointer type to use to read in the data depending on the
// BYTES_PER_LDG template param. In theory, this can support all powers of 2
// up to 16.
Expand Down Expand Up @@ -765,7 +770,14 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
// global memory. (This will be a single) thread per row of the
// input/output matrices.
const int idx = k * thread_row + k_idx;
output[idx] = T(max_val);
if constexpr (Norm_Weights) {
const int idx_in_cta = k * thread_row_in_cta + k_idx;
row_output[idx_in_cta] = T(max_val);
weight_sum += T(max_val);
}
else {
output[idx] = T(max_val);
}
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
}
Expand All @@ -788,6 +800,16 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
}
}
}
if constexpr (Norm_Weights) {
#pragma unroll
for (int k_idx = 0; k_idx < k; ++k_idx) {
if (thread_group_idx == 0) {
const int idx = k * thread_row + k_idx;
const int idx_in_cta = k * thread_row_in_cta + k_idx;
output[idx] = row_output[idx_in_cta] / weight_sum;
}
}
}
}

namespace detail {
Expand All @@ -807,7 +829,7 @@ struct TopkConstants {
};
} // namespace detail

template <typename T, int EXPERTS, int WARPS_PER_TB, typename IdxT = int>
template <typename T, int EXPERTS, int WARPS_PER_TB, bool Norm_Weights = false, typename IdxT = int>
void topk_gating_softmax_launcher_helper(const T* input,
T* output,
IdxT* indices,
Expand All @@ -826,8 +848,9 @@ void topk_gating_softmax_launcher_helper(const T* input,
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
<<<num_blocks, block_dim, 0, stream>>>(
static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, Norm_Weights>
<<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof(T), stream>>>(
input, output, num_rows, indices, source_row, k);
}

Expand Down
16 changes: 11 additions & 5 deletions custom_ops/gpu_ops/moe/moe_topk_select.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@ void moe_topk_select_kernel(const T* input,
static constexpr int WARPS_PER_TB = 4;

#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
case N: { \
if (apply_norm_weight) { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
} else { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
} \
break; \
}
int64_t tem_num_experts = num_experts;
if(bias != nullptr || apply_norm_weight) tem_num_experts = 0;
// when bias is not none, set tem_num_experts to 0 to follow the default branch
if(bias != nullptr) tem_num_experts = 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里加一个注释,说明下,否则怪怪的,意思就是说当bias不空的时候,只走default逻辑

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

switch (tem_num_experts) {
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
Expand Down
Loading