Skip to content

Commit 2e78311

Browse files
authored
[Optimize]Add norm_weights feature for topk_gating_softmax (#3372)
1 parent 666ab65 commit 2e78311

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

custom_ops/gpu_ops/moe/fused_moe_op.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ template <typename T,
570570
int NUM_EXPERTS,
571571
int WARPS_PER_CTA,
572572
int BYTES_PER_LDG,
573+
bool Norm_Weights = false,
573574
typename IdxT = int>
574575
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
575576
void topk_gating_softmax(const T* input,
@@ -628,6 +629,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
628629
// We compute row offset for each thread sub-group
629630
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
630631
const int thread_row = warp_base_row + thread_row_in_warp;
632+
const int thread_row_in_cta = thread_row - cta_base_row;
631633

632634
// Threads with indices out of bounds should early exit here.
633635
if (thread_row >= num_rows) return;
@@ -643,6 +645,9 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
643645
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
644646
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
645647

648+
T weight_sum = static_cast<T>(0);
649+
extern __shared__ T row_output[];
650+
646651
// Determine the pointer type to use to read in the data depending on the
647652
// BYTES_PER_LDG template param. In theory, this can support all powers of 2
648653
// up to 16.
@@ -765,7 +770,14 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
765770
// global memory. (This will be a single) thread per row of the
766771
// input/output matrices.
767772
const int idx = k * thread_row + k_idx;
768-
output[idx] = T(max_val);
773+
if constexpr (Norm_Weights) {
774+
const int idx_in_cta = k * thread_row_in_cta + k_idx;
775+
row_output[idx_in_cta] = T(max_val);
776+
weight_sum += T(max_val);
777+
}
778+
else {
779+
output[idx] = T(max_val);
780+
}
769781
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
770782
source_rows[idx] = k_idx * num_rows + thread_row;
771783
}
@@ -788,6 +800,16 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
788800
}
789801
}
790802
}
803+
if constexpr (Norm_Weights) {
804+
#pragma unroll
805+
for (int k_idx = 0; k_idx < k; ++k_idx) {
806+
if (thread_group_idx == 0) {
807+
const int idx = k * thread_row + k_idx;
808+
const int idx_in_cta = k * thread_row_in_cta + k_idx;
809+
output[idx] = row_output[idx_in_cta] / weight_sum;
810+
}
811+
}
812+
}
791813
}
792814

793815
namespace detail {
@@ -807,7 +829,7 @@ struct TopkConstants {
807829
};
808830
} // namespace detail
809831

810-
template <typename T, int EXPERTS, int WARPS_PER_TB, typename IdxT = int>
832+
template <typename T, int EXPERTS, int WARPS_PER_TB, bool Norm_Weights = false, typename IdxT = int>
811833
void topk_gating_softmax_launcher_helper(const T* input,
812834
T* output,
813835
IdxT* indices,
@@ -826,8 +848,9 @@ void topk_gating_softmax_launcher_helper(const T* input,
826848
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
827849

828850
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
829-
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
830-
<<<num_blocks, block_dim, 0, stream>>>(
851+
static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
852+
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, Norm_Weights>
853+
<<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof(T), stream>>>(
831854
input, output, num_rows, indices, source_row, k);
832855
}
833856

custom_ops/gpu_ops/moe/moe_topk_select.cu

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,19 @@ void moe_topk_select_kernel(const T* input,
4444
static constexpr int WARPS_PER_TB = 4;
4545

4646
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
47-
case N: { \
48-
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
49-
input, output, indices, source_row, num_rows, num_experts, k, stream); \
50-
break; \
47+
case N: { \
48+
if (apply_norm_weight) { \
49+
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>( \
50+
input, output, indices, source_row, num_rows, num_experts, k, stream); \
51+
} else { \
52+
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>( \
53+
input, output, indices, source_row, num_rows, num_experts, k, stream); \
54+
} \
55+
break; \
5156
}
5257
int64_t tem_num_experts = num_experts;
53-
if(bias != nullptr || apply_norm_weight) tem_num_experts = 0;
58+
// when bias is not none, set tem_num_experts to 0 to follow the default branch
59+
if(bias != nullptr) tem_num_experts = 0;
5460
switch (tem_num_experts) {
5561
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
5662
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)

0 commit comments

Comments
 (0)