/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlu_ops_api.h"

namespace xllm::kernel::mlu {

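// scaled_matmul: thin wrapper over the MLU smooth-quant GEMM kernel.
// It infers the activation/weight quantization layouts from the shapes of
// `a_scale` and `b_scale`, allocates the output tensor when none is supplied,
// and forwards all arguments to tmo::torch_api::scaled_matmul. `b` is taken
// in [N, K] layout (trans_b = true), so the result has shape
// [a.size(0), b.size(0)].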
torch::Tensor scaled_matmul(
    const torch::Tensor& a,
    const torch::Tensor& b,
    const std::optional<torch::Tensor>& a_scale,
    const torch::Tensor& b_scale,
    c10::ScalarType output_dtype,
    const std::optional<torch::Tensor>& bias /* = c10::nullopt */,
    const std::optional<torch::Tensor>& c /* = c10::nullopt */,
    const std::string& act_mode /* = "none" */,
    int64_t quant_bit_size /* = 8 */,
    double alpha /* = 1.0 */,
    double beta /* = 1.0 */,
    bool use_hp_active /* = false */,
    int64_t a_quant_bit_size /* = -1 */,
    const std::optional<torch::Tensor>& a_calib /* = c10::nullopt */,
    const std::optional<torch::Tensor>& b_calib /* = c10::nullopt */,
    const std::optional<torch::Tensor>& output /* = c10::nullopt */
) {
  // Check: only support w8a8 quantization for now.
  TORCH_CHECK(quant_bit_size == 8 && a_quant_bit_size == 8,
              "scaled_matmul only supports w8a8 quantization (quant_bit_size "
              "== 8, a_quant_bit_size == 8) for now. "
              "Got quant_bit_size = ",
              quant_bit_size,
              ", a_quant_bit_size = ",
              a_quant_bit_size,
              ".");

  // Only the smooth_quant algorithm is supported for now, so the activation
  // scale is required.
  TORCH_CHECK(a_scale.has_value(),
              "scaled_matmul requires a_scale for smooth_quant (w8a8).");
  std::string quant_algo = "smooth_quant";

  // Infer the quantization layouts from the scale tensor shapes: a 1-D
  // activation scale means per-token quantization; a 1-D weight scale means
  // per-channel, while a 2-D weight scale with fewer rows than `b` means
  // per-block and otherwise group-wise.
  std::string a_quant_layout = (a_scale.value().dim() == 1)
                                   ? "quantize_per_token"
                                   : "quantize_group_wise";
  std::string b_quant_layout = "quantize_per_channel";
  if (b_scale.dim() > 1) {
    if (b_scale.size(0) < b.size(0)) {
      b_quant_layout = "quantize_per_block";
    } else {
      b_quant_layout = "quantize_group_wise";
    }
  }
  std::optional<torch::Tensor> gemm_output_scale = c10::nullopt;

  at::ScalarType torch_half = at::ScalarType::Half;
  at::ScalarType torch_bfloat16 = at::ScalarType::BFloat16;

  TORCH_CHECK(output_dtype == torch_half || output_dtype == torch_bfloat16,
              "output dtype must be half or bfloat16, but got: ",
              output_dtype,
              ".");

  // Use the caller-provided output tensor if present; otherwise allocate one
  // of shape [M, N].
  torch::Tensor output_tensor;
  if (output.has_value()) {
    output_tensor = output.value();
  } else {
    output_tensor = at::empty(
        {a.size(0), b.size(0)},
        torch::TensorOptions().dtype(output_dtype).device(a.device()));
  }

  // Call underlying kernel for smooth_quant
  tmo::torch_api::scaled_matmul(output_tensor,
                                a,
                                b,
                                a_scale,
                                c10::nullopt,  // a_zero
                                a_calib,
                                b_scale,
                                c10::nullopt,  // b_zero
                                b_calib,
                                bias,
                                c,
                                c10::nullopt,  // c_scale
                                c10::nullopt,  // c_zero
                                gemm_output_scale,
                                c10::nullopt,  // gemm_output_zero
                                quant_algo,
                                a_quant_layout,
                                b_quant_layout,
                                a_quant_bit_size,
                                quant_bit_size,
                                act_mode,
                                use_hp_active,
                                1.0,    // act_coef
                                alpha,
                                beta,
                                false,  // trans_a
                                true    // trans_b
  );
  return output_tensor;
}
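
// Usage sketch (illustrative only; the shapes and tensor names below are
// assumptions derived from this wrapper, not from an xLLM call site):
//
//   a:       [M, K] int8 activations     b:       [N, K] int8 weights
//   a_scale: [M]    per-token scales     b_scale: [N]    per-channel scales
//
//   auto out = scaled_matmul(a, b, a_scale, b_scale, at::ScalarType::BFloat16,
//                            /*bias=*/c10::nullopt, /*c=*/c10::nullopt,
//                            /*act_mode=*/"none", /*quant_bit_size=*/8,
//                            /*alpha=*/1.0, /*beta=*/1.0,
//                            /*use_hp_active=*/false, /*a_quant_bit_size=*/8);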

}  // namespace xllm::kernel::mlu