diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h index 09d705d410..34ec90837b 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_op.h +++ b/custom_ops/gpu_ops/moe/fused_moe_op.h @@ -150,64 +150,6 @@ __launch_bounds__(TPB) __global__ } } -template -__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, - T* output, - IdxT* indices, - int* source_rows, - T* softmax_max_prob, - const int64_t num_experts, - const int64_t k, - const int64_t num_rows) { - using cub_kvp = cub::KeyValuePair; - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmpStorage; - - cub_kvp thread_kvp; - cub::ArgMax arg_max; - - const int block_row = blockIdx.x + blockIdx.y * gridDim.x; - if (block_row >= num_rows) { - return; - } - - const bool should_process_row = true; - const int thread_read_offset = block_row * num_experts; - - for (int k_idx = 0; k_idx < k; ++k_idx) { - thread_kvp.key = 0; - thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities - - cub_kvp inp_kvp; - for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { - const int idx = thread_read_offset + expert; - inp_kvp.key = expert; - inp_kvp.value = inputs_after_softmax[idx]; - - for (int prior_k = 0; prior_k < k_idx; ++prior_k) { - const IdxT prior_winning_expert = indices[k * block_row + prior_k]; - - if (prior_winning_expert == expert) { - inp_kvp = thread_kvp; - } - } - - thread_kvp = arg_max(inp_kvp, thread_kvp); - } - - const cub_kvp result_kvp = - BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); - if (threadIdx.x == 0) { - const int idx = k * block_row + k_idx; - // restore normalized probes - output[idx] = result_kvp.value / T(softmax_max_prob[idx]); - indices[idx] = should_process_row ? result_kvp.key : num_experts; - source_rows[idx] = k_idx * num_rows + block_row; - } - __syncthreads(); - } -} - template __launch_bounds__(TPB) __global__ void moe_softmax(const T* input, T* output, @@ -262,11 +204,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax(const T* input, } template -__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, - const T* bias, +__launch_bounds__(TPB) __global__ void group_moe_top_k(const T* inputs_after_softmax, T* output, IdxT* indices, int* source_rows, + T* softmax_max_prob, const int64_t num_experts, const int64_t k, const int64_t num_rows) { @@ -293,7 +235,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { const int idx = thread_read_offset + expert; inp_kvp.key = expert; - inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ; + inp_kvp.value = inputs_after_softmax[idx]; for (int prior_k = 0; prior_k < k_idx; ++prior_k) { const IdxT prior_winning_expert = indices[k * block_row + prior_k]; @@ -310,7 +252,8 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); if (threadIdx.x == 0) { const int idx = k * block_row + k_idx; - output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + // restore normalized probes + output[idx] = result_kvp.value / T(softmax_max_prob[idx]); indices[idx] = should_process_row ? result_kvp.key : num_experts; source_rows[idx] = k_idx * num_rows + block_row; } @@ -318,93 +261,8 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, } } -template -__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input, - const T* bias, - T* output, - IdxT* indices, - int* source_rows, - const int64_t num_experts, - const int64_t k, - const int64_t num_rows) { - // softmax - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmpStorage; - - __shared__ float normalizing_factor; - __shared__ float float_max; - - int globalIdx = blockIdx.x + blockIdx.y * gridDim.x; - if (globalIdx >= num_rows) { - return; - } - const int64_t thread_row_offset = globalIdx * num_experts; - const int64_t idx = thread_row_offset+threadIdx.x; - - cub::Sum sum; - - float threadData = (threadIdx.x < num_experts) ? static_cast(input[idx]) :(-FLT_MAX); - - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); - if (threadIdx.x == 0) { - float_max = maxElem; - } - __syncthreads(); - - float threadDataSub = threadData - float_max; - float threadDataExp = exp(threadDataSub); - - const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum); - - if (threadIdx.x == 0) { - normalizing_factor = 1.f / Z; - } - __syncthreads(); - - T val = T(threadDataExp * normalizing_factor); - - // top_k - using cub_kvp = cub::KeyValuePair; - using BlockReduceP = cub::BlockReduce; - __shared__ typename BlockReduceP::TempStorage tmpStorageP; - - cub_kvp thread_kvp; - cub::ArgMax arg_max; - - for (int k_idx = 0; k_idx < k; ++k_idx) { - thread_kvp.key = 0; - thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities - - if (threadIdx.x < num_experts) { - cub_kvp inp_kvp; - int expert = threadIdx.x; - inp_kvp.key = expert; - inp_kvp.value = bias ? val + bias[expert] : val; - - for (int prior_k = 0; prior_k < k_idx; ++prior_k) { - const IdxT prior_winning_expert = indices[k * globalIdx + prior_k]; - - if (prior_winning_expert == expert) { - inp_kvp = thread_kvp; - } - } - thread_kvp = arg_max(inp_kvp, thread_kvp); - } - - const cub_kvp result_kvp = - BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max); - if (threadIdx.x == 0) { - const int cur_idx = k * globalIdx + k_idx; - output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; - indices[cur_idx] = result_kvp.key; - source_rows[cur_idx] = k_idx * num_rows + globalIdx; - } - __syncthreads(); - } -} - -template -__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax, +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const T* bias, T* output, IdxT* indices, @@ -427,10 +285,12 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so const bool should_process_row = true; const int thread_read_offset = block_row * num_experts; T weight_sum = static_cast(0); + T* row_outputs = nullptr; - extern __shared__ char smem[]; - - T* row_outputs = reinterpret_cast(smem); + if constexpr (NormWeights){ + extern __shared__ char smem[]; + row_outputs = reinterpret_cast(smem); + } for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; @@ -457,28 +317,32 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); if (threadIdx.x == 0) { const int idx = k * block_row + k_idx; - // output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; indices[idx] = should_process_row ? result_kvp.key : num_experts; source_rows[idx] = k_idx * num_rows + block_row; - T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; - row_outputs[k_idx] = row_out; - weight_sum += row_out; + if constexpr (NormWeights){ + T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + row_outputs[k_idx] = row_out; + weight_sum += row_out; + } + else{ + output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + } } __syncthreads(); } - if (threadIdx.x < WARP_SIZE) { - weight_sum = __shfl_sync(0xffffffff, weight_sum, 0); - } - - if (threadIdx.x < k) { - output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum; + if constexpr (NormWeights){ + if (threadIdx.x < WARP_SIZE) { + weight_sum = __shfl_sync(0xffffffff, weight_sum, 0); + } + if (threadIdx.x < k) { + output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum; + } } } - -template -__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input, +template +__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input, const T* bias, T* output, IdxT* indices, @@ -532,8 +396,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i cub::ArgMax arg_max; T weight_sum = static_cast(0); - extern __shared__ char smem[]; - T* row_outputs = reinterpret_cast(smem); + T* row_outputs = nullptr; + if constexpr (NormWeights){ + extern __shared__ char smem[]; + row_outputs = reinterpret_cast(smem); + } for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; @@ -560,22 +427,28 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i if (threadIdx.x == 0) { const int cur_idx = k * globalIdx + k_idx; - T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; - row_outputs[k_idx] = row_out; - weight_sum += row_out; - indices[cur_idx] = result_kvp.key; source_rows[cur_idx] = k_idx * num_rows + globalIdx; + + if constexpr (NormWeights) { + T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; + row_outputs[k_idx] = row_out; + weight_sum += row_out; + } + else { + output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; + } } __syncthreads(); } + if constexpr (NormWeights) { + if (threadIdx.x < WARP_SIZE) { + weight_sum = __shfl_sync(0xffffffff, weight_sum, 0); + } - if (threadIdx.x < WARP_SIZE) { - weight_sum = __shfl_sync(0xffffffff, weight_sum, 0); - } - - if (threadIdx.x < k) { - output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum; + if (threadIdx.x < k) { + output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum; + } } } @@ -1015,7 +888,7 @@ static void run(const T* input, group_experts, softmax_num_rows); const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); - moe_top_k + group_moe_top_k <<>>(softmax, output, indices, diff --git a/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu b/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu index 0a7b5ac6a8..8f780e00af 100644 --- a/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu +++ b/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu @@ -102,7 +102,7 @@ void moe_redundant_topk_select_kernel(const T* input, else { assert(k<=TPB); if (apply_norm_weight) { - moe_softmax_top_k_normed_fused + moe_softmax_top_k_fused <<>>(input, bias, output, @@ -112,7 +112,7 @@ void moe_redundant_topk_select_kernel(const T* input, k, num_rows); } else { - moe_softmax_top_k_fused + moe_softmax_top_k_fused <<>>(input, bias, output, diff --git a/custom_ops/gpu_ops/moe/moe_topk_select.cu b/custom_ops/gpu_ops/moe/moe_topk_select.cu index 7647a0ed69..7217682f45 100644 --- a/custom_ops/gpu_ops/moe/moe_topk_select.cu +++ b/custom_ops/gpu_ops/moe/moe_topk_select.cu @@ -68,7 +68,7 @@ void moe_topk_select_kernel(const T* input, moe_softmax<<>>( input, softmax, num_experts, num_rows); if (apply_norm_weight) { - moe_top_k_normed + moe_top_k <<>>(softmax, bias, output, @@ -78,7 +78,7 @@ void moe_topk_select_kernel(const T* input, k, num_rows); } else { - moe_top_k + moe_top_k <<>>(softmax, bias, output, @@ -93,7 +93,7 @@ void moe_topk_select_kernel(const T* input, else { assert(k<=TPB); if (apply_norm_weight) { - moe_softmax_top_k_normed_fused + moe_softmax_top_k_fused <<>>(input, bias, output, @@ -103,7 +103,7 @@ void moe_topk_select_kernel(const T* input, k, num_rows); } else { - moe_softmax_top_k_fused + moe_softmax_top_k_fused <<>>(input, bias, output, diff --git a/test/operators/test_moe_top_k_select.py b/test/operators/test_moe_top_k_select.py new file mode 100644 index 0000000000..63d93e067c --- /dev/null +++ b/test/operators/test_moe_top_k_select.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import moe_topk_select + + +class Test(unittest.TestCase): + def setUp(self): + """ + Initialize. + """ + paddle.seed(2024) + print(paddle.device.cuda.get_device_properties()) + print(paddle.__git_commit__) + self.batch_size = 1500 + self.num_experts = 128 + self.top_k = 8 + + def moe_topk_select_ref(self, gate_out: paddle.Tensor, bias: paddle.Tensor, top_k: int, apply_norm_weight: bool): + gate_out_after_softmax = paddle.nn.functional.softmax(gate_out, axis=-1) + topk_weights_ref, topk_ids_ref = paddle.topk(gate_out_after_softmax, k=top_k, axis=-1) + + if bias is not None: + gate_out_after_softmax_bias = gate_out_after_softmax + bias + _, topk_ids_ref = paddle.topk(gate_out_after_softmax_bias, k=top_k, axis=-1) + batch_indices = paddle.arange(gate_out.shape[0]).unsqueeze(-1).expand_as(topk_ids_ref) + topk_weights_ref = gate_out_after_softmax.gather_nd(paddle.stack([batch_indices, topk_ids_ref], axis=-1)) + + if apply_norm_weight: + topk_weights_ref = topk_weights_ref / topk_weights_ref.sum(axis=-1, keepdim=True) + + return topk_ids_ref, topk_weights_ref + + def test_moe_topk_select(self): + """ + Check moe_topk_select. + """ + gate_out = paddle.rand([self.batch_size, self.num_experts], dtype="float32") + gate_correction_bias = paddle.rand([1, self.num_experts], dtype="float32") + gate_correction_bias = gate_correction_bias / 10.0 + + for apply_norm_weight in [True, False]: + for bias in [None, gate_correction_bias]: + topk_ids_ref, topk_weights_ref = self.moe_topk_select_ref( + gate_out, bias, self.top_k, apply_norm_weight + ) + for fused in [True, False]: + topk_ids, topk_weights = moe_topk_select( + gate_out, + bias, + self.top_k, + apply_norm_weight, + fused, + ) + + np.testing.assert_allclose( + topk_ids_ref, + topk_ids, + rtol=1e-05, + atol=1e-05, + ) + + np.testing.assert_allclose( + topk_weights_ref, + topk_weights, + rtol=1e-05, + atol=1e-05, + ) + + +if __name__ == "__main__": + unittest.main()