diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h
index 34ec90837b..2bc07cf39a 100644
--- a/custom_ops/gpu_ops/moe/fused_moe_op.h
+++ b/custom_ops/gpu_ops/moe/fused_moe_op.h
@@ -570,6 +570,7 @@ template <typename T,
           int NUM_EXPERTS,
           int WARPS_PER_CTA,
           int BYTES_PER_LDG,
-          typename IdxT>
+          typename IdxT,
+          bool Norm_Weights = false>
 __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
     void topk_gating_softmax(const T* input,
@@ -628,6 +629,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
   // We compute row offset for each thread sub-group
   const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
   const int thread_row = warp_base_row + thread_row_in_warp;
+  const int thread_row_in_cta = thread_row - cta_base_row;
 
   // Threads with indices out of bounds should early exit here.
   if (thread_row >= num_rows) return;
@@ -643,6 +645,9 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
   const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
   const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
 
+  T weight_sum = static_cast<T>(0);
+  extern __shared__ T row_output[];
+
   // Determine the pointer type to use to read in the data depending on the
   // BYTES_PER_LDG template param. In theory, this can support all powers of 2
   // up to 16.
@@ -765,7 +770,14 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
         // global memory. (This will be a single) thread per row of the
         // input/output matrices.
         const int idx = k * thread_row + k_idx;
-        output[idx] = T(max_val);
+        if constexpr (Norm_Weights) {
+          const int idx_in_cta = k * thread_row_in_cta + k_idx;
+          row_output[idx_in_cta] = T(max_val);
+          weight_sum += T(max_val);
+        }
+        else {
+          output[idx] = T(max_val);
+        }
         indices[idx] = should_process_row ? expert : NUM_EXPERTS;
         source_rows[idx] = k_idx * num_rows + thread_row;
       }
@@ -788,6 +800,16 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
       }
     }
   }
+  if constexpr (Norm_Weights) {
+#pragma unroll
+    for (int k_idx = 0; k_idx < k; ++k_idx) {
+      if (thread_group_idx == 0) {
+        const int idx = k * thread_row + k_idx;
+        const int idx_in_cta = k * thread_row_in_cta + k_idx;
+        output[idx] = row_output[idx_in_cta] / weight_sum;
+      }
+    }
+  }
 }
 
 namespace detail {
@@ -807,7 +829,7 @@ struct TopkConstants {
 };
 }  // namespace detail
 
-template <typename T, typename IdxT, int EXPERTS, int WARPS_PER_TB>
+template <typename T, typename IdxT, int EXPERTS, int WARPS_PER_TB, bool Norm_Weights>
 void topk_gating_softmax_launcher_helper(const T* input,
                                          T* output,
                                          IdxT* indices,
@@ -826,8 +848,9 @@ void topk_gating_softmax_launcher_helper(const T* input,
   const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
 
   dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
-  topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, IdxT>
-      <<<num_blocks, block_dim, 0, stream>>>(
+  static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
+  topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, IdxT, Norm_Weights>
+      <<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof(T), stream>>>(
       input, output, num_rows, indices, source_row, k);
 }
 
diff --git a/custom_ops/gpu_ops/moe/moe_topk_select.cu b/custom_ops/gpu_ops/moe/moe_topk_select.cu
index 7217682f45..1798689c0e 100644
--- a/custom_ops/gpu_ops/moe/moe_topk_select.cu
+++ b/custom_ops/gpu_ops/moe/moe_topk_select.cu
@@ -44,13 +44,19 @@ void moe_topk_select_kernel(const T* input,
   static constexpr int WARPS_PER_TB = 4;
 
 #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N)                                   \
-  case N: {                                                                    \
-    topk_gating_softmax_launcher_helper<T, IdxT, N, WARPS_PER_TB>(             \
-        input, output, indices, source_row, num_rows, num_experts, k, stream); \
-    break;                                                                     \
+  case N: {                                                                    \
+    if (apply_norm_weight) {                                                   \
+      topk_gating_softmax_launcher_helper<T, IdxT, N, WARPS_PER_TB, true>(     \
+        input, output, indices, source_row, num_rows, num_experts, k, stream); \
+    } else {                                                                   \
+      topk_gating_softmax_launcher_helper<T, IdxT, N, WARPS_PER_TB, false>(    \
+        input, output, indices, source_row, num_rows, num_experts, k, stream); \
+    }                                                                          \
+    break;                                                                     \
   }
   int64_t tem_num_experts = num_experts;
-  if(bias != nullptr || apply_norm_weight) tem_num_experts = 0;
+  // when bias is not none, set tem_num_experts to 0 to follow the default branch
+  if(bias != nullptr) tem_num_experts = 0;
   switch (tem_num_experts) {
     LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
     LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)