|
| 1 | +/* |
| 2 | + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "tensorrt_llm/thop/fp8Quantize.h" |
| 18 | + |
| 19 | +#include <ATen/cuda/EmptyTensor.h> |
| 20 | + |
| 21 | +#include "cutlass/numeric_types.h" |
| 22 | +#include "pytorch_extension_utils.h" |
| 23 | +#include "tensorrt_llm/thop/thUtils.h" |
| 24 | + |
| 25 | +namespace torch_ext { |
| 26 | + |
| 27 | +// input: [M, K], fp32/fp16/bf16/fp8_quantized |
| 28 | +// isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in |
| 29 | +// linear layout. See FP4QuantizationSFLayout enum for more details about the two layouts. |
| 30 | +// returns |
| 31 | +std::tuple<at::Tensor, at::Tensor> mxfp8_quantize(at::Tensor input, bool isSfSwizzledLayout) { |
| 32 | + CHECK_TH_CUDA(input); |
| 33 | + CHECK_CONTIGUOUS(input); |
| 34 | + |
| 35 | + auto const& inputShape = input.sizes(); |
| 36 | + auto const& rank = inputShape.size(); |
| 37 | + |
| 38 | + TORCH_CHECK(rank >= 2, "Input should be >=2D tensor."); |
| 39 | + int64_t m = 1; |
| 40 | + for (size_t i = 0; i < rank - 1; i++) { |
| 41 | + m *= inputShape[i]; |
| 42 | + } |
| 43 | + auto const k = inputShape[rank - 1]; |
| 44 | + int32_t const sfVecSize = 32; |
| 45 | + TORCH_CHECK(k % sfVecSize == 0); |
| 46 | + |
| 47 | + std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end()); |
| 48 | + outputShape[rank - 1] = k; |
| 49 | + |
| 50 | + at::Tensor valueFP8 = |
| 51 | + at::detail::empty_cuda(outputShape, at::ScalarType::Float8_e4m3fn, input.device(), |
| 52 | + /* stride */ std::nullopt); |
| 53 | + |
| 54 | + int64_t SFSize = isSfSwizzledLayout |
| 55 | + ? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize) |
| 56 | + : tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize); |
| 57 | + |
| 58 | + at::Tensor scaleFP8SF = at::detail::empty_cuda({SFSize}, SF_DTYPE, input.device(), |
| 59 | + /* stride */ std::nullopt); // 1D tensor |
| 60 | + |
| 61 | + const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); |
| 62 | + |
| 63 | + auto const layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4 |
| 64 | + : tensorrt_llm::FP4QuantizationSFLayout::LINEAR; |
| 65 | + |
| 66 | +#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T) \ |
| 67 | + tensorrt_llm::kernels::invokeMxFP8Quantization<T>( \ |
| 68 | + 1, m, k, reinterpret_cast<T*>(input.data_ptr()), \ |
| 69 | + reinterpret_cast<int64_t*>(valueFP8.data_ptr()), \ |
| 70 | + reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), layout, mMultiProcessorCount, \ |
| 71 | + at::cuda::getCurrentCUDAStream(input.get_device())); |
| 72 | + |
| 73 | + if (input.scalar_type() == at::ScalarType::Half) { |
| 74 | + LAUNCH_MXFP8_QUANTIZE_KERNEL(half) |
| 75 | + } else if (input.scalar_type() == at::ScalarType::BFloat16) { |
| 76 | +#ifdef ENABLE_BF16 |
| 77 | + LAUNCH_MXFP8_QUANTIZE_KERNEL(__nv_bfloat16) |
| 78 | +#else |
| 79 | + C10_THROW_ERROR(NotImplementedError, |
| 80 | + "BFloat16 must be enabled to quantize an bf16 tensor to mxfp8."); |
| 81 | +#endif |
| 82 | + } else { |
| 83 | + C10_THROW_ERROR(NotImplementedError, |
| 84 | + "mxfp8_quantize only supports input tensor with dtypes fp16/bf16."); |
| 85 | + } |
| 86 | + |
| 87 | +#undef LAUNCH_MXFP8_QUANTIZE_KERNEL |
| 88 | + |
| 89 | + return {valueFP8, scaleFP8SF}; |
| 90 | +} |
| 91 | + |
| 92 | +inline uint8_t float_to_ue8m0(float value) { |
| 93 | + if (value == 0.0f) { |
| 94 | + return 0x00; |
| 95 | + } |
| 96 | + constexpr uint32_t FP32_MANTISSA_BITS = 23; |
| 97 | + uint32_t val_u32 = *reinterpret_cast<uint32_t*>(&value); |
| 98 | + uint8_t exponent = (val_u32 >> FP32_MANTISSA_BITS); |
| 99 | + uint32_t mantissa = val_u32 & 0x7FFFFF; |
| 100 | + // Round up exponent and deal with satfinite. |
| 101 | + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { |
| 102 | + ++exponent; |
| 103 | + } |
| 104 | + return exponent; |
| 105 | +} |
| 106 | + |
| 107 | +// Used in tests to quantize mxe4m3 tensors on host. |
| 108 | +std::tuple<at::Tensor, at::Tensor> mxfp8_quantize_host(at::Tensor x_fp32, |
| 109 | + bool is_sf_swizzled_layout) { |
| 110 | + int32_t const sf_vec_size = 32; |
| 111 | + CHECK_CPU_INPUT(x_fp32, c10::ScalarType::Float); |
| 112 | + auto data_shape = x_fp32.sizes(); |
| 113 | + TORCH_CHECK(data_shape.size() == 2, "x_fp32 should be 2D tensor."); |
| 114 | + int num_tokens = data_shape[0]; |
| 115 | + int hidden_dim = data_shape[1]; |
| 116 | + int groups_per_hidden_dim = hidden_dim / sf_vec_size; |
| 117 | + |
| 118 | + at::Tensor fp8_tensor = at::detail::empty_cpu({num_tokens, hidden_dim}, at::ScalarType::Byte, |
| 119 | + /* pinned */ true, at::MemoryFormat::Contiguous); |
| 120 | + int64_t sf_size = |
| 121 | + is_sf_swizzled_layout |
| 122 | + ? tensorrt_llm::computeFP4SwizzledLayoutSFSize(num_tokens, hidden_dim / sf_vec_size) |
| 123 | + : tensorrt_llm::computeFP4LinearLayoutSFSize(num_tokens, hidden_dim / sf_vec_size); |
| 124 | + at::Tensor scale_tensor = |
| 125 | + at::detail::empty_cpu({sf_size}, SF_DTYPE, /* pinned */ true, at::MemoryFormat::Contiguous); |
| 126 | + |
| 127 | + tensorrt_llm::FP4QuantizationSFLayout layout = |
| 128 | + is_sf_swizzled_layout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4 |
| 129 | + : tensorrt_llm::FP4QuantizationSFLayout::LINEAR; |
| 130 | + |
| 131 | + for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) { |
| 132 | + for (int group = 0; group < groups_per_hidden_dim; ++group) { |
| 133 | + float* fp32_ptr = x_fp32.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size; |
| 134 | + uint8_t* fp8_ptr = fp8_tensor.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size; |
| 135 | + |
| 136 | + uint8_t* scale_ue8m08sf_ptr = scale_tensor.data_ptr<uint8_t>(); |
| 137 | + |
| 138 | + float local_amax = 0.0f; |
| 139 | + for (int ki = 0; ki < sf_vec_size; ++ki) { |
| 140 | + local_amax = std::max(std::abs(fp32_ptr[ki]), local_amax); |
| 141 | + } |
| 142 | + |
| 143 | + local_amax *= (1.f / 448.0f); |
| 144 | + |
| 145 | + uint8_t scale_ue8m0 = float_to_ue8m0(local_amax); |
| 146 | + auto const inv_scale = (scale_ue8m0 == 0) ? 1 : exp2f(127 - static_cast<float>(scale_ue8m0)); |
| 147 | + |
| 148 | + scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0], groups_per_hidden_dim, layout)] = |
| 149 | + scale_ue8m0; |
| 150 | + |
| 151 | + for (int ki = 0; ki < sf_vec_size; ++ki) { |
| 152 | + float const scaled_fp32_value = fp32_ptr[ki] * inv_scale; |
| 153 | + auto fp8_value = cutlass::float_e4m3_t{scaled_fp32_value}; |
| 154 | + fp8_ptr[ki] = *reinterpret_cast<uint8_t*>(&fp8_value); |
| 155 | + } |
| 156 | + } |
| 157 | + } |
| 158 | + return std::make_tuple(fp8_tensor, scale_tensor); |
| 159 | +} |
| 160 | + |
| 161 | +// Used in tests to dequantize mxe4m3 tensors on host. |
| 162 | +at::Tensor mxfp8_dequantize_host(at::Tensor value_e4m3, at::Tensor scale_ue8m08sf, |
| 163 | + bool is_sf_swizzled_layout) { |
| 164 | + int32_t const sf_vec_size = 32; |
| 165 | + CHECK_CPU_INPUT(value_e4m3, c10::ScalarType::Byte); |
| 166 | + CHECK_CPU_INPUT(scale_ue8m08sf, SF_DTYPE); |
| 167 | + auto data_shape = value_e4m3.sizes(); |
| 168 | + auto scale_shape = scale_ue8m08sf.sizes(); |
| 169 | + TORCH_CHECK(data_shape.size() == 2, "value_e4m3 should be 2D tensor."); |
| 170 | + TORCH_CHECK(scale_shape.size() == 1, "scale_ue8m08sf should be 1D tensor."); |
| 171 | + at::Tensor float_tensor = |
| 172 | + at::detail::empty_cpu({data_shape[0], data_shape[1]}, at::ScalarType::Float, |
| 173 | + /* pinned */ true, at::MemoryFormat::Contiguous); |
| 174 | + |
| 175 | + int hidden_dim = data_shape[1]; |
| 176 | + int groups_per_hidden_dim = hidden_dim / sf_vec_size; |
| 177 | + |
| 178 | + tensorrt_llm::FP4QuantizationSFLayout layout = |
| 179 | + is_sf_swizzled_layout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4 |
| 180 | + : tensorrt_llm::FP4QuantizationSFLayout::LINEAR; |
| 181 | + for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) { |
| 182 | + for (int group = 0; group < groups_per_hidden_dim; ++group) { |
| 183 | + float* float_ptr = float_tensor.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size; |
| 184 | + uint8_t* fp8_ptr = value_e4m3.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size; |
| 185 | + uint8_t* scale_ue8m08sf_ptr = scale_ue8m08sf.data_ptr<uint8_t>(); |
| 186 | + uint8_t fp8_scale = scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0], |
| 187 | + groups_per_hidden_dim, layout)]; |
| 188 | + |
| 189 | + float scale_float; |
| 190 | + uint32_t scale_float_u32 = uint32_t(fp8_scale) << 23; |
| 191 | + memcpy(&scale_float, &scale_float_u32, sizeof(scale_float)); |
| 192 | + |
| 193 | + for (int ki = 0; ki < sf_vec_size; ++ki) { |
| 194 | + uint8_t fp8_u8_repr = fp8_ptr[ki]; |
| 195 | + auto fp32 = static_cast<float>(*reinterpret_cast<cutlass::float_e4m3_t*>(&fp8_u8_repr)); |
| 196 | + float value = fp32 * scale_float; |
| 197 | + float_ptr[ki] = value; |
| 198 | + } |
| 199 | + } |
| 200 | + } |
| 201 | + return float_tensor; |
| 202 | +} |
| 203 | +} // namespace torch_ext |
| 204 | + |
| 205 | +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { |
| 206 | + m.def("mxfp8_dequantize_host", &torch_ext::mxfp8_dequantize_host); |
| 207 | + m.def("mxfp8_quantize_host", &torch_ext::mxfp8_quantize_host); |
| 208 | + m.def("mxfp8_quantize", &torch_ext::mxfp8_quantize); |
| 209 | +} |
0 commit comments