Skip to content

Refactor moe_topk_select op to use apply_norm_weight as a template parameter #3345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 51 additions & 178 deletions custom_ops/gpu_ops/moe/fused_moe_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,64 +150,6 @@ __launch_bounds__(TPB) __global__
}
}

// Top-k expert selection over pre-softmaxed router probabilities (group-MoE
// variant: rescales each winner by softmax_max_prob — presumably a per-slot
// renormalization factor produced by an upstream group-softmax kernel; TODO
// confirm against the producer of softmax_max_prob).
//
// Launch shape: one block per row (2D grid flattened via blockIdx.x/y), TPB
// threads per block. Requires no dynamic shared memory.
//
// For each of the k rounds, all threads cooperatively arg-max across
// num_experts entries of the current row, skipping experts that won earlier
// rounds, then thread 0 writes the round's result.
//
// Outputs, all indexed by idx = k * row + k_idx:
//   output[idx]      - winning probability divided by T(softmax_max_prob[idx])
//   indices[idx]     - winning expert id
//   source_rows[idx] - k_idx * num_rows + row (scatter/permutation helper)
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
T* output,
IdxT* indices,
int* source_rows,
T* softmax_max_prob,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
// Block-wide argmax reduction over (expert index, probability) pairs.
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;

cub_kvp thread_kvp;
cub::ArgMax arg_max;

// Map the (possibly 2D) grid onto a flat row id; extra blocks exit early.
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
if (block_row >= num_rows) {
return;
}

const bool should_process_row = true;
const int thread_read_offset = block_row * num_experts;

for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
// Sentinel below any real probability so the first candidate always wins.
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities

cub_kvp inp_kvp;
// Strided scan: each thread covers experts threadIdx.x, threadIdx.x+TPB, ...
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];

// Exclude experts already selected in previous rounds by replacing the
// candidate with the thread's current (lower-or-equal) running best.
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * block_row + prior_k];

if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}

thread_kvp = arg_max(inp_kvp, thread_kvp);
}

// Reduce per-thread winners to the block-wide winner for this round.
const cub_kvp result_kvp =
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
// Restore normalized probabilities by dividing out the stored max factor.
output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
}
// Barrier: indices[] written by thread 0 is read by all threads next round,
// and tmpStorage is reused. Uniform across the block (no divergence here).
__syncthreads();
}
}

template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
T* output,
Expand Down Expand Up @@ -262,11 +204,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
}

template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
const T* bias,
__launch_bounds__(TPB) __global__ void group_moe_top_k(const T* inputs_after_softmax,
T* output,
IdxT* indices,
int* source_rows,
T* softmax_max_prob,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
Expand All @@ -293,7 +235,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
inp_kvp.value = inputs_after_softmax[idx];

for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
Expand All @@ -310,101 +252,17 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
// restore normalized probes
output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
}
}

// Fused softmax + top-k expert selection over raw router logits.
//
// Launch shape: one block per row (2D grid flattened), TPB threads per block;
// assumes num_experts <= TPB (each expert is handled by exactly one thread —
// note there is no strided loop here, unlike the unfused kernel). No dynamic
// shared memory required.
//
// Phase 1 computes a numerically stable softmax of the row (max-subtract,
// exp, normalize). Phase 2 runs k rounds of block-wide argmax, optionally
// adding a per-expert bias for the comparison only.
//
// Outputs, all indexed by cur_idx = k * row + k_idx:
//   output[cur_idx]      - winning softmax probability (bias removed again if
//                          bias was applied, so the stored weight is unbiased)
//   indices[cur_idx]     - winning expert id
//   source_rows[cur_idx] - k_idx * num_rows + row (scatter helper)
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
// softmax
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;

// Broadcast slots: only thread 0 holds valid reduction results, so they are
// published through shared memory for the whole block.
__shared__ float normalizing_factor;
__shared__ float float_max;

int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_experts;
const int64_t idx = thread_row_offset+threadIdx.x;

cub::Sum sum;

// Threads beyond num_experts contribute -FLT_MAX (ignored by max, exp -> 0).
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);

const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
// Barrier also separates the two reductions' reuse of tmpStorage.
__syncthreads();

float threadDataSub = threadData - float_max;
// NOTE(review): exp() takes the double-precision path here; expf() would
// keep the computation in float — confirm whether the extra precision is
// intentional.
float threadDataExp = exp(threadDataSub);

const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);

if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();

// This thread's softmax probability for expert threadIdx.x.
T val = T(threadDataExp * normalizing_factor);

// top_k
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP;

cub_kvp thread_kvp;
cub::ArgMax arg_max;

for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
// Sentinel below any real probability so the first candidate always wins.
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities

if (threadIdx.x < num_experts) {
cub_kvp inp_kvp;
int expert = threadIdx.x;
inp_kvp.key = expert;
// Bias shifts the ranking only; the unbiased weight is recovered on write.
inp_kvp.value = bias ? val + bias[expert] : val;

// Exclude experts that already won earlier rounds by substituting the
// thread's current running best for the candidate.
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];

if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}

// Block-wide argmax; threads >= num_experts contribute only the sentinel.
const cub_kvp result_kvp =
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;
// Subtract the bias back out so output stores the pure softmax weight.
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
indices[cur_idx] = result_kvp.key;
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
}
// indices[] written by thread 0 is read by all threads in the next round,
// and tmpStorageP is reused; barrier is reached uniformly by the block.
__syncthreads();
}
}

template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
const T* bias,
T* output,
IdxT* indices,
Expand All @@ -427,10 +285,12 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
const bool should_process_row = true;
const int thread_read_offset = block_row * num_experts;
T weight_sum = static_cast<T>(0);
T* row_outputs = nullptr;

extern __shared__ char smem[];

T* row_outputs = reinterpret_cast<T*>(smem);
if constexpr (NormWeights){
extern __shared__ char smem[];
row_outputs = reinterpret_cast<T*>(smem);
}

for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
Expand All @@ -457,28 +317,32 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;

T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
if constexpr (NormWeights){
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
}
else{
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
}
}
__syncthreads();
}
if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}

if (threadIdx.x < k) {
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
if constexpr (NormWeights){
if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}
if (threadIdx.x < k) {
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
}
}
}


template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
Expand Down Expand Up @@ -532,8 +396,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
cub::ArgMax arg_max;

T weight_sum = static_cast<T>(0);
extern __shared__ char smem[];
T* row_outputs = reinterpret_cast<T*>(smem);
T* row_outputs = nullptr;
if constexpr (NormWeights){
extern __shared__ char smem[];
row_outputs = reinterpret_cast<T*>(smem);
}

for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
Expand All @@ -560,22 +427,28 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;

T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;

indices[cur_idx] = result_kvp.key;
source_rows[cur_idx] = k_idx * num_rows + globalIdx;

if constexpr (NormWeights) {
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
}
else {
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
}
}
__syncthreads();
}
if constexpr (NormWeights) {
if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}

if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}

if (threadIdx.x < k) {
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
if (threadIdx.x < k) {
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
}
}
}

Expand Down Expand Up @@ -1015,7 +888,7 @@ static void run(const T* input,
group_experts,
softmax_num_rows);
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_top_k<T, TPB>
group_moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
output,
indices,
Expand Down
4 changes: 2 additions & 2 deletions custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ void moe_redundant_topk_select_kernel(const T* input,
else {
assert(k<=TPB);
if (apply_norm_weight) {
moe_softmax_top_k_normed_fused<T, TPB>
moe_softmax_top_k_fused<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
bias,
output,
Expand All @@ -112,7 +112,7 @@ void moe_redundant_topk_select_kernel(const T* input,
k,
num_rows);
} else {
moe_softmax_top_k_fused<T, TPB>
moe_softmax_top_k_fused<T, TPB, false>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
bias,
output,
Expand Down
8 changes: 4 additions & 4 deletions custom_ops/gpu_ops/moe/moe_topk_select.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void moe_topk_select_kernel(const T* input,
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
if (apply_norm_weight) {
moe_top_k_normed<T, TPB>
moe_top_k<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(softmax,
bias,
output,
Expand All @@ -78,7 +78,7 @@ void moe_topk_select_kernel(const T* input,
k,
num_rows);
} else {
moe_top_k<T, TPB>
moe_top_k<T, TPB, false>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
bias,
output,
Expand All @@ -93,7 +93,7 @@ void moe_topk_select_kernel(const T* input,
else {
assert(k<=TPB);
if (apply_norm_weight) {
moe_softmax_top_k_normed_fused<T, TPB>
moe_softmax_top_k_fused<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
bias,
output,
Expand All @@ -103,7 +103,7 @@ void moe_topk_select_kernel(const T* input,
k,
num_rows);
} else {
moe_softmax_top_k_fused<T, TPB>
moe_softmax_top_k_fused<T, TPB, false>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
bias,
output,
Expand Down
Loading
Loading