|
| 1 | +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | + |
| 16 | +#pragma once |
| 17 | + |
| 18 | +#include "helper.h" |
| 19 | +#include "mc_fused_moe_helper.h" |
| 20 | +#include "fused_moe_op.h" |
| 21 | + |
| 22 | +__global__ void compute_total_rows_before_expert_kernel( |
| 23 | + int* sorted_experts, |
| 24 | + const int64_t sorted_experts_len, |
| 25 | + const int64_t num_experts, |
| 26 | + int32_t* total_rows_before_expert) { |
| 27 | + const int expert = blockIdx.x * blockDim.x + threadIdx.x; |
| 28 | + if (expert >= num_experts) return; |
| 29 | + |
| 30 | + total_rows_before_expert[expert] = |
| 31 | + find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); |
| 32 | +} |
| 33 | + |
| 34 | +void compute_total_rows_before_expert(int* sorted_indices, |
| 35 | + const int64_t total_indices, |
| 36 | + const int64_t num_experts, |
| 37 | + int32_t* total_rows_before_expert, |
| 38 | + cudaStream_t stream) { |
| 39 | + const int threads = std::min(int64_t(1024), num_experts); |
| 40 | + const int blocks = (num_experts + threads - 1) / threads; |
| 41 | + |
| 42 | + compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>( |
| 43 | + sorted_indices, total_indices, num_experts, total_rows_before_expert); |
| 44 | +} |
| 45 | + |
| 46 | +template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC> |
| 47 | +void FusedMoeKernel(const paddle::Tensor& input, |
| 48 | + const paddle::Tensor& gate_weight, |
| 49 | + const paddle::Tensor& ffn1_weight, |
| 50 | + const paddle::optional<paddle::Tensor>& ffn1_scale, |
| 51 | + const paddle::optional<paddle::Tensor>& ffn1_bias, |
| 52 | + const paddle::Tensor& ffn2_weight, |
| 53 | + const paddle::optional<paddle::Tensor>& ffn2_scale, |
| 54 | + const paddle::optional<paddle::Tensor>& ffn2_bias, |
| 55 | + const std::string& quant_method, |
| 56 | + const int moe_topk, |
| 57 | + const bool group_moe, |
| 58 | + const bool norm_topk_prob, |
| 59 | + paddle::Tensor* output) { |
| 60 | + typedef PDTraits<T> traits_; |
| 61 | + typedef typename traits_::DataType DataType_; |
| 62 | + typedef typename traits_::data_t data_t; |
| 63 | + |
| 64 | + auto* output_data = output->data<data_t>(); |
| 65 | + |
| 66 | + auto moe_compute = McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method); |
| 67 | + |
| 68 | + moe_compute.computeFFN( |
| 69 | + &input, |
| 70 | + &gate_weight, |
| 71 | + &ffn1_weight, |
| 72 | + ffn1_scale ? ffn1_scale.get_ptr() : nullptr, |
| 73 | + ffn1_bias ? ffn1_bias.get_ptr() : nullptr, |
| 74 | + &ffn2_weight, |
| 75 | + ffn2_scale ? ffn2_scale.get_ptr() : nullptr, |
| 76 | + ffn2_bias ? ffn2_bias.get_ptr() : nullptr, |
| 77 | + nullptr, |
| 78 | + moe_topk, |
| 79 | + group_moe, |
| 80 | + norm_topk_prob, |
| 81 | + 1.0, // ComputeFFN |
| 82 | + "ffn", |
| 83 | + output); |
| 84 | +} |
| 85 | + |
| 86 | + |
| 87 | +std::vector<paddle::Tensor> FusedExpertMoe( |
| 88 | + const paddle::Tensor& input, |
| 89 | + const paddle::Tensor& gate_weight, |
| 90 | + const paddle::Tensor& ffn1_weight, |
| 91 | + const paddle::Tensor& ffn2_weight, |
| 92 | + const paddle::optional<paddle::Tensor>& ffn1_bias, |
| 93 | + const paddle::optional<paddle::Tensor>& ffn1_scale, |
| 94 | + const paddle::optional<paddle::Tensor>& ffn2_bias, |
| 95 | + const paddle::optional<paddle::Tensor>& ffn2_scale, |
| 96 | + const std::string& quant_method, |
| 97 | + const int moe_topk, |
| 98 | + const bool norm_topk_prob, |
| 99 | + const bool group_moe) { |
| 100 | + const auto input_type = input.dtype(); |
| 101 | + auto output = paddle::empty_like(input); |
| 102 | + |
| 103 | + switch (input_type) { |
| 104 | + case paddle::DataType::BFLOAT16: |
| 105 | + FusedMoeKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(input, |
| 106 | + gate_weight, |
| 107 | + ffn1_weight, |
| 108 | + ffn1_scale, |
| 109 | + ffn1_bias, |
| 110 | + ffn2_weight, |
| 111 | + ffn2_scale, |
| 112 | + ffn2_bias, |
| 113 | + quant_method, |
| 114 | + moe_topk, |
| 115 | + group_moe, |
| 116 | + norm_topk_prob, |
| 117 | + &output); |
| 118 | + break; |
| 119 | + // case paddle::DataType::FLOAT16: |
| 120 | + // FusedMoeKernel<paddle::DataType::FLOAT16>(input, |
| 121 | + // gate_weight, |
| 122 | + // ffn1_weight, |
| 123 | + // ffn1_scale, |
| 124 | + // ffn1_bias, |
| 125 | + // ffn2_weight, |
| 126 | + // ffn2_scale, |
| 127 | + // ffn2_bias, |
| 128 | + // quant_method, |
| 129 | + // moe_topk, |
| 130 | + // group_moe, |
| 131 | + // norm_topk_prob, |
| 132 | + // &output); |
| 133 | + // break; |
| 134 | + default: |
| 135 | + PD_THROW("Only support bf16 for FusedMoeKernel"); |
| 136 | + } |
| 137 | + return {output}; |
| 138 | +} |
| 139 | + |
| 140 | +std::vector<std::vector<int64_t>> FusedExpertMoeInferShape( |
| 141 | + const std::vector<int64_t>& input_shape, |
| 142 | + const std::vector<int64_t>& gate_weight_shape, |
| 143 | + const std::vector<int64_t>& ffn1_weight_shape, |
| 144 | + const std::vector<int64_t>& ffn2_weight_shape, |
| 145 | + const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape, |
| 146 | + const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape, |
| 147 | + const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape, |
| 148 | + const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) { |
| 149 | + return {input_shape}; |
| 150 | +} |
| 151 | + |
| 152 | +std::vector<paddle::DataType> FusedExpertMoeInferDtype( |
| 153 | + const paddle::DataType& input_dtype, |
| 154 | + const paddle::DataType& gate_weight_dtype, |
| 155 | + const paddle::DataType& ffn1_weight_dtype, |
| 156 | + const paddle::DataType& ffn2_weight_dtype, |
| 157 | + const paddle::optional<paddle::DataType>& ffn1_bias_dtype, |
| 158 | + const paddle::optional<paddle::DataType>& ffn1_scale_dtype, |
| 159 | + const paddle::optional<paddle::DataType>& ffn2_bias_dtype, |
| 160 | + const paddle::optional<paddle::DataType>& ffn2_scale_dtype) { |
| 161 | + return {input_dtype}; |
| 162 | +} |
| 163 | + |
| 164 | + |
| 165 | +PD_BUILD_OP(fused_expert_moe) |
| 166 | + .Inputs({"input", |
| 167 | + "gate_weight", |
| 168 | + "ffn1_weight", |
| 169 | + "ffn2_weight", |
| 170 | + paddle::Optional("ffn1_bias"), |
| 171 | + paddle::Optional("ffn1_scale"), |
| 172 | + paddle::Optional("ffn2_bias"), |
| 173 | + paddle::Optional("ffn2_scale")}) |
| 174 | + .Outputs({"output"}) |
| 175 | + .Attrs({"quant_method:std::string", |
| 176 | + "moe_topk:int", |
| 177 | + "norm_topk_prob:bool", |
| 178 | + "group_moe:bool"}) |
| 179 | + .SetKernelFn(PD_KERNEL(FusedExpertMoe)) |
| 180 | + .SetInferShapeFn(PD_INFER_SHAPE(FusedExpertMoeInferShape)) |
| 181 | + .SetInferDtypeFn(PD_INFER_DTYPE(FusedExpertMoeInferDtype)); |
0 commit comments