|
| 1 | +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "mctlass/epilogue/thread/scale_type.h" |
| 16 | +#include "mctlass/half.h" |
| 17 | +#include "mctlass/layout/matrix.h" |
| 18 | +#include "mctlass/mctlass_ex.h" |
| 19 | +#include "paddle/phi/backends/gpu/gpu_context.h" |
| 20 | +#include "paddle/phi/common/datatype_traits.h" |
| 21 | +#include "paddle/phi/core/kernel_registry.h" |
| 22 | +#include "paddle/phi/kernels/funcs/weight_only_gemv.h" |
| 23 | +#include "paddle/phi/kernels/weight_only_linear_kernel.h" |
| 24 | + |
| 25 | +namespace phi { |
| 26 | + |
| 27 | +template <typename T, typename Context> |
| 28 | +void WeightOnlyLinearKernel(const Context& dev_ctx, |
| 29 | + const DenseTensor& x, |
| 30 | + const DenseTensor& weight, |
| 31 | + const paddle::optional<DenseTensor>& bias, |
| 32 | + const DenseTensor& weight_scale, |
| 33 | + const std::string& weight_dtype, |
| 34 | + const int32_t arch, |
| 35 | + const int32_t group_size, |
| 36 | + DenseTensor* out) { |
| 37 | + dev_ctx.template Alloc<T>(out); |
| 38 | + const T* x_data = x.data<T>(); |
| 39 | + const int8_t* weight_data = weight.data<int8_t>(); |
| 40 | + const T* bias_data = bias ? bias.get().data<T>() : nullptr; |
| 41 | + const T* weight_scale_data = weight_scale.data<T>(); |
| 42 | + T* out_data = out->data<T>(); |
| 43 | + const auto x_dims = x.dims(); |
| 44 | + const auto w_dims = weight.dims(); |
| 45 | + int n = group_size > 0 ? weight_scale.dims()[1] : weight_scale.dims()[0]; |
| 46 | + int k = w_dims[1]; |
| 47 | + int m = x.numel() / k; |
| 48 | + |
| 49 | + using ElementA = maca_bfloat16; |
| 50 | + using ElementB_w8a16 = int8_t; |
| 51 | + using ElementB_w4a16 = uint8_t; |
| 52 | + using ElementC = maca_bfloat16; |
| 53 | + using ElementCompute = float; |
| 54 | + using ElementOutput = ElementC; |
| 55 | + using LayoutA = mctlass::layout::RowMajor; |
| 56 | + using LayoutB = mctlass::layout::ColumnMajor; |
| 57 | + using LayoutC = mctlass::layout::RowMajor; |
| 58 | + using ArchTag = mctlass::arch::Sm80; |
| 59 | + |
| 60 | + using mctlassGemmScaleOp_w8a16_nobias = |
| 61 | + mctlassGemmScale<ElementA, |
| 62 | + LayoutA, |
| 63 | + ElementB_w8a16, |
| 64 | + LayoutB, |
| 65 | + ElementC, |
| 66 | + LayoutC, |
| 67 | + ElementCompute, |
| 68 | + ArchTag, |
| 69 | + mctlass::epilogue::thread::ScaleType::NoScaleAsBs>; |
| 70 | + |
| 71 | + using mctlassGemmScaleOp_w8a16_bias = |
| 72 | + mctlassGemmScale<ElementA, |
| 73 | + LayoutA, |
| 74 | + ElementB_w8a16, |
| 75 | + LayoutB, |
| 76 | + ElementC, |
| 77 | + LayoutC, |
| 78 | + ElementCompute, |
| 79 | + ArchTag, |
| 80 | + mctlass::epilogue::thread::ScaleType::ScaleOnlyBias>; |
| 81 | + |
| 82 | + using mctlassGemmScaleOp_w4a16_nobias = |
| 83 | + mctlassGemmScale<ElementA, |
| 84 | + LayoutA, |
| 85 | + ElementB_w4a16, |
| 86 | + LayoutB, |
| 87 | + ElementC, |
| 88 | + LayoutC, |
| 89 | + ElementCompute, |
| 90 | + ArchTag, |
| 91 | + mctlass::epilogue::thread::ScaleType::NoScaleAsBs>; |
| 92 | + |
| 93 | + using mctlassGemmScaleOp_w4a16_bias = |
| 94 | + mctlassGemmScale<ElementA, |
| 95 | + LayoutA, |
| 96 | + ElementB_w4a16, |
| 97 | + LayoutB, |
| 98 | + ElementC, |
| 99 | + LayoutC, |
| 100 | + ElementCompute, |
| 101 | + ArchTag, |
| 102 | + mctlass::epilogue::thread::ScaleType::ScaleOnlyBias>; |
| 103 | + |
| 104 | + mctlass::gemm::GemmCoord problem_size(m, n, k); |
| 105 | + |
| 106 | + if (weight_dtype == "int8") { |
| 107 | + if (bias_data == nullptr) { |
| 108 | + mctlassGemmScaleOp_w8a16_nobias mctlass_op; |
| 109 | + typename mctlassGemmScaleOp_w8a16_nobias::Arguments arguments{ |
| 110 | + mctlass::gemm::GemmUniversalMode::kGemmQuantB, |
| 111 | + problem_size, |
| 112 | + 1, |
| 113 | + mctlassGemmScaleOp_w8a16_nobias::epilogueParams( |
| 114 | + reinterpret_cast<const maca_bfloat16*>(bias_data)), |
| 115 | + mctlassGemmScaleOp_w8a16_nobias::quantscaleParams( |
| 116 | + 1, |
| 117 | + group_size, |
| 118 | + reinterpret_cast<const maca_bfloat16*>(weight_scale_data)), |
| 119 | + reinterpret_cast<const maca_bfloat16*>(x_data), |
| 120 | + weight_data, |
| 121 | + reinterpret_cast<const maca_bfloat16*>(out_data), |
| 122 | + out_data, |
| 123 | + m * k, |
| 124 | + n * k, |
| 125 | + m * n, |
| 126 | + m * n, |
| 127 | + k, |
| 128 | + k, |
| 129 | + n, |
| 130 | + n}; |
| 131 | + mctlass_op(arguments); |
| 132 | + } else { |
| 133 | + mctlassGemmScaleOp_w8a16_bias mctlass_op; |
| 134 | + typename mctlassGemmScaleOp_w8a16_bias::Arguments arguments{ |
| 135 | + mctlass::gemm::GemmUniversalMode::kGemmQuantB, |
| 136 | + problem_size, |
| 137 | + 1, |
| 138 | + mctlassGemmScaleOp_w8a16_bias::epilogueParams( |
| 139 | + reinterpret_cast<const maca_bfloat16*>(bias_data)), |
| 140 | + mctlassGemmScaleOp_w8a16_bias::quantscaleParams( |
| 141 | + 1, |
| 142 | + group_size, |
| 143 | + reinterpret_cast<const maca_bfloat16*>(weight_scale_data)), |
| 144 | + reinterpret_cast<const maca_bfloat16*>(x_data), |
| 145 | + weight_data, |
| 146 | + reinterpret_cast<const maca_bfloat16*>(out_data), |
| 147 | + out_data, |
| 148 | + m * k, |
| 149 | + n * k, |
| 150 | + m * n, |
| 151 | + m * n, |
| 152 | + k, |
| 153 | + k, |
| 154 | + n, |
| 155 | + n}; |
| 156 | + mctlass_op(arguments); |
| 157 | + } |
| 158 | + } else if (weight_dtype == "int4") { |
| 159 | + if (bias_data == nullptr) { |
| 160 | + mctlassGemmScaleOp_w4a16_nobias mctlass_op; |
| 161 | + typename mctlassGemmScaleOp_w4a16_nobias::Arguments arguments{ |
| 162 | + mctlass::gemm::GemmUniversalMode::kGemmQuantB, |
| 163 | + problem_size, |
| 164 | + 1, |
| 165 | + mctlassGemmScaleOp_w4a16_nobias::epilogueParams( |
| 166 | + reinterpret_cast<const maca_bfloat16*>(bias_data)), |
| 167 | + mctlassGemmScaleOp_w4a16_nobias::quantscaleParams( |
| 168 | + 1, |
| 169 | + group_size, |
| 170 | + reinterpret_cast<const maca_bfloat16*>(weight_scale_data)), |
| 171 | + reinterpret_cast<const maca_bfloat16*>(x_data), |
| 172 | + weight_data, |
| 173 | + reinterpret_cast<const maca_bfloat16*>(out_data), |
| 174 | + out_data, |
| 175 | + m * k, |
| 176 | + n * k, |
| 177 | + m * n, |
| 178 | + m * n, |
| 179 | + k, |
| 180 | + k, |
| 181 | + n, |
| 182 | + n}; |
| 183 | + mctlass_op(arguments); |
| 184 | + } else { |
| 185 | + mctlassGemmScaleOp_w4a16_bias mctlass_op; |
| 186 | + typename mctlassGemmScaleOp_w4a16_bias::Arguments arguments{ |
| 187 | + mctlass::gemm::GemmUniversalMode::kGemmQuantB, |
| 188 | + problem_size, |
| 189 | + 1, |
| 190 | + mctlassGemmScaleOp_w4a16_bias::epilogueParams( |
| 191 | + reinterpret_cast<const maca_bfloat16*>(bias_data)), |
| 192 | + mctlassGemmScaleOp_w4a16_bias::quantscaleParams( |
| 193 | + 1, |
| 194 | + group_size, |
| 195 | + reinterpret_cast<const maca_bfloat16*>(weight_scale_data)), |
| 196 | + reinterpret_cast<const maca_bfloat16*>(x_data), |
| 197 | + weight_data, |
| 198 | + reinterpret_cast<const maca_bfloat16*>(out_data), |
| 199 | + out_data, |
| 200 | + m * k, |
| 201 | + n * k, |
| 202 | + m * n, |
| 203 | + m * n, |
| 204 | + k, |
| 205 | + k, |
| 206 | + n, |
| 207 | + n}; |
| 208 | + mctlass_op(arguments); |
| 209 | + } |
| 210 | + } |
| 211 | +} |
| 212 | +} // namespace phi |
| 213 | + |
| 214 | +PD_REGISTER_PLUGIN_KERNEL(weight_only_linear, |
| 215 | + metax_gpu, |
| 216 | + ALL_LAYOUT, |
| 217 | + phi::WeightOnlyLinearKernel, |
| 218 | + phi::dtype::float16, |
| 219 | + phi::dtype::bfloat16) {} |
0 commit comments