Merged
Changes from 1 commit
44 changes: 44 additions & 0 deletions csrc/nv_internal/cpp/kernels/quantization.cu
@@ -70,6 +70,41 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const

////////////////////////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////////////////////////
// MXFP8 Quantization

template <typename T>
void invokeMxFP8Quantization(int b, int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
FP4QuantizationSFLayout layout, int multiProcessorCount,
cudaStream_t stream) {
// Fixed SF_VEC_SIZE as 32
static constexpr int SF_VEC_SIZE = 32;

// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / CVT_FP4_ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = std::max(1u, 2048u / block.x);
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
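// Illustrative sizing (assuming n = 4096, i.e. 8 elements per thread):
// block.x = min(4096 / 8, 512) = 512, numBlocksPerSM = max(1, 2048 / 512) = 4,
// so at most multiProcessorCount * 4 blocks are launched and each block
// strides over rows in steps of gridDim.x.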

// Launch the cvt kernel.
cudaLaunchConfig_t config;
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
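// Note: with PDL enabled, this kernel may be scheduled before the preceding
// kernel on the stream has fully drained; the griddepcontrol.wait inside
// quantize_with_block_size is what actually gates reading the input.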
cudaLaunchKernelEx(
&config,
quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
m, n, input, nullptr, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput), layout);
}

// Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
template <typename T, typename QuantT>
void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows,
@@ -320,6 +355,9 @@ template void invokeBatchedFP4Quantization<half, 16>(
template void invokeBatchedFP4Quantization<half, 32>(
int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput,
bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream);
template void invokeMxFP8Quantization<half>(int b, int m, int n, half const* input, int64_t* output,
int32_t* SFOuput, FP4QuantizationSFLayout layout,
int multiProcessorCount, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeFP4Quantization<__nv_bfloat16, 16>(int m, int n, __nv_bfloat16 const* input,
float const* SFScale, int64_t* output,
@@ -341,6 +379,12 @@ template void invokeBatchedFP4Quantization<__nv_bfloat16, 32>(
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
cudaStream_t stream);
template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n,
__nv_bfloat16 const* input, int64_t* output,
int32_t* SFOuput,
FP4QuantizationSFLayout layout,
int multiProcessorCount, cudaStream_t stream);

#endif

#ifdef ENABLE_FP8
85 changes: 85 additions & 0 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
@@ -917,6 +917,91 @@ cvt_fp8_to_fp4(
#endif
}

template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) quantize_with_block_size(
#else
quantize_with_block_size(
#endif
int32_t numbatches, int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout, FP4QuantizationSFLayout layout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)

// The elements per thread.
static constexpr int ELTS_PER_THREAD = quantization_type == BlockScaleQuantizationType::FP8_TO_FP4
? CVT_FP8_TO_FP4_ELTS_PER_THREAD
: CVT_FP4_ELTS_PER_THREAD;

using PackedVec = PackedVec<Type>;
static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD; // 2 or 4
static_assert(sizeof(PackedVec) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched.");

// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
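// For the MXFP8 path, invokeMxFP8Quantization passes SFScale == nullptr, so
// SFScaleVal stays 1.0f and only the per-block UE8M0 scale factor is applied.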

int numPaddedRows = numRows;
int numPaddedCols = numCols;
if (layout == FP4QuantizationSFLayout::SWIZZLED_128x4) {
// The number of padded rows considering 128x4 SF layout.
numPaddedRows = PadUpFn(numRows, 128);
numPaddedCols = PadUpFn(numCols, 4 * SF_VEC_SIZE);
} else if (layout == FP4QuantizationSFLayout::SWIZZLED_8x4) {
// The number of padded rows considering 8x4 SF layout.
numPaddedRows = PadUpFn(numRows, 8);
numPaddedCols = PadUpFn(numCols, 4 * SF_VEC_SIZE);
}

// The number of threads in the column dimension
int numColThreads = numCols / ELTS_PER_THREAD;
int numPaddedColThreads = numPaddedCols / ELTS_PER_THREAD;
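// Illustrative example: numRows = 100, numCols = 96, SF_VEC_SIZE = 32 and the
// SWIZZLED_128x4 layout give numPaddedRows = 128, numPaddedCols = 128,
// numColThreads = 12 and numPaddedColThreads = 16; the extra rows/threads only
// zero out the padded scale-factor entries below.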

asm volatile("griddepcontrol.wait;");
// Input tensor batch/row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numPaddedRows; rowIdx += gridDim.x) {
for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
for (int colIdx = threadIdx.x; colIdx < numPaddedColThreads; colIdx += blockDim.x) {
std::optional<int> optionalBatchIdx = batchIdx;
std::optional<int> optionalNumRows = numRows;

// The SF output pointer.
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF, SF_VEC_SIZE>(
optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numCols, SFout, layout);

// Set the SF padding to 0.
if (rowIdx >= numRows || colIdx >= numColThreads) {
if (sf_out != nullptr) {
sf_out[0] = 0x00;
}
} else {
int64_t inOffset =
static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset as a packed vector.
int64_t outOffset = inOffset;

// Dispatch the quantization kernel.
if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
reinterpret_cast<uint32_t*>(out)[outOffset] =
cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
} else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
reinterpret_cast<uint64_t*>(out)[outOffset] =
cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
sf_out);
} else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
reinterpret_cast<uint64_t*>(out)[outOffset] =
cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
}
}
}
}
}
asm volatile("griddepcontrol.launch_dependents;");
#endif
}

__global__ void nvfp4_block_scale_interleave_kernel(int numbatches, int numRows, int numCols,
uint8_t const* SFIn, uint8_t* SFOutput);
} // namespace kernels
12 changes: 12 additions & 0 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.h
@@ -41,6 +41,13 @@ enum class FP4QuantizationSFLayout {
LINEAR
};

// This denotes the input and output data types of the block scale quantization.
enum class BlockScaleQuantizationType {
FP16_TO_FP4 = 0,
FP8_TO_FP4 = 1,
FP16_TO_MXFP8 = 2,
};

#define PadUpFn(X, Y) ((X + Y - 1) / (Y) * (Y))

// totalColumn should be in SFMatrix, not activation Matrix, so no sfVecSize needed.
@@ -86,5 +93,10 @@ void invokeNVFP4BlockScaleInterleaveReverse(int b, int m, int n, uint8_t const*
uint8_t* SFOutput, int multiProcessorCount,
cudaStream_t stream = 0);

template <typename T>
void invokeMxFP8Quantization(int b, int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
FP4QuantizationSFLayout layout, int multiProcessorCount,
cudaStream_t stream = 0);

} // namespace kernels
} // namespace tensorrt_llm
212 changes: 212 additions & 0 deletions csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp
@@ -0,0 +1,212 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/thop/fp8Quantize.h"

#include <ATen/cuda/EmptyTensor.h>

#include "cutlass/numeric_types.h"
#include "pytorch_extension_utils.h"
#include "tensorrt_llm/thop/thUtils.h"

namespace torch_ext {

// self: [M, K], fp16/bf16/fp8_quantized
// isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in
// linear layout. See FP4QuantizationSFLayout enum for more details about the two layouts.
// returns
// self_fp8, self_block_scale_factors
// self_fp8: [M, K], uint8_t
// self_block_scale_factors: ceil(M / 128) * 128 * ceil(K / 32 / 4) * 4, uint8_t
std::tuple<at::Tensor, at::Tensor> mxfp8_quantize(at::Tensor self, bool isSfSwizzledLayout) {
CHECK_TH_CUDA(self);
CHECK_CONTIGUOUS(self);

auto const& inputShape = self.sizes();
auto const& rank = inputShape.size();

TORCH_CHECK(rank >= 2, "Input should be >=2D tensor.");
int64_t m = 1;
for (size_t i = 0; i < rank - 1; i++) {
m *= inputShape[i];
}
auto const k = inputShape[rank - 1];
int32_t const sfVecSize = 32;
TORCH_CHECK(k % sfVecSize == 0);

std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
outputShape[rank - 1] = k;

at::Tensor valueFP8 =
at::detail::empty_cuda(outputShape, at::ScalarType::Float8_e4m3fn, self.device(),
/* stride */ std::nullopt);

int64_t SFSize = isSfSwizzledLayout
? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)
: tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize);

at::Tensor scaleFP8SF = at::detail::empty_cuda({SFSize}, SF_DTYPE, self.device(),
/* stride */ std::nullopt); // 1D tensor

const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

auto const layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
: tensorrt_llm::FP4QuantizationSFLayout::LINEAR;

#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T) \
tensorrt_llm::kernels::invokeMxFP8Quantization<T>( \
1, m, k, reinterpret_cast<T*>(self.data_ptr()), \
reinterpret_cast<int64_t*>(valueFP8.data_ptr()), \
reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), layout, mMultiProcessorCount, \
at::cuda::getCurrentCUDAStream(self.get_device()));

if (self.scalar_type() == at::ScalarType::Half) {
LAUNCH_MXFP8_QUANTIZE_KERNEL(half)
} else if (self.scalar_type() == at::ScalarType::BFloat16) {
#ifdef ENABLE_BF16
LAUNCH_MXFP8_QUANTIZE_KERNEL(__nv_bfloat16)
#else
C10_THROW_ERROR(NotImplementedError,
"BFloat16 must be enabled to quantize an bf16 tensor to mxfp8.");
#endif
} else {
C10_THROW_ERROR(NotImplementedError,
"mxfp8_quantize only supports input tensor with dtypes fp16/bf16.");
}

#undef LAUNCH_MXFP8_QUANTIZE_KERNEL

return {valueFP8, scaleFP8SF};
}
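
// Hypothetical usage sketch (not part of this PR; assumes the ATen factory
// functions are available through the existing includes): quantize a
// [128, 256] half tensor and rely on the shapes documented above.
inline void mxfp8_quantize_usage_sketch() {
  at::Tensor x = at::randn({128, 256}, at::dtype(at::kHalf).device(at::kCUDA));
  auto [x_fp8, x_sf] = mxfp8_quantize(x, /*isSfSwizzledLayout=*/true);
  // x_fp8: [128, 256], Float8_e4m3fn.
  // x_sf: ceil(128 / 128) * 128 * ceil(256 / 32 / 4) * 4 = 1024 uint8 entries.
}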

inline uint8_t float_to_ue8m0(float value) {
if (value == 0.0f) {
return 0x00;
}
constexpr uint32_t FP32_MANTISSA_BITS = 23;
uint32_t val_u32 = *reinterpret_cast<uint32_t*>(&value);
uint8_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
uint32_t mantissa = val_u32 & 0x7FFFFF;
// Round up exponent and deal with satfinite.
if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
++exponent;
}
return exponent;
}
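
// Illustrative values (decode is 2^(e - 127), see mxfp8_dequantize_host below):
// float_to_ue8m0(0.0f) == 0x00, float_to_ue8m0(0.5f) == 126 (2^-1),
// float_to_ue8m0(1.0f) == 127 (2^0), float_to_ue8m0(1.5f) == 128 (rounds up to 2^1).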

// Used in tests to quantize mxe4m3 tensors on host.
std::tuple<at::Tensor, at::Tensor> mxfp8_quantize_host(at::Tensor x_fp32,
bool is_sf_swizzled_layout) {
int32_t const sf_vec_size = 32;
CHECK_CPU_INPUT(x_fp32, c10::ScalarType::Float);
auto data_shape = x_fp32.sizes();
TORCH_CHECK(data_shape.size() == 2, "x_fp32 should be 2D tensor.");
int num_tokens = data_shape[0];
int hidden_dim = data_shape[1];
int groups_per_hidden_dim = hidden_dim / sf_vec_size;

at::Tensor fp8_tensor = at::detail::empty_cpu({num_tokens, hidden_dim}, at::ScalarType::Byte,
/* pinned */ true, at::MemoryFormat::Contiguous);
int64_t sf_size =
is_sf_swizzled_layout
? tensorrt_llm::computeFP4SwizzledLayoutSFSize(num_tokens, hidden_dim / sf_vec_size)
: tensorrt_llm::computeFP4LinearLayoutSFSize(num_tokens, hidden_dim / sf_vec_size);
at::Tensor scale_tensor =
at::detail::empty_cpu({sf_size}, SF_DTYPE, /* pinned */ true, at::MemoryFormat::Contiguous);

tensorrt_llm::FP4QuantizationSFLayout layout =
is_sf_swizzled_layout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
: tensorrt_llm::FP4QuantizationSFLayout::LINEAR;

for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) {
for (int group = 0; group < groups_per_hidden_dim; ++group) {
float* fp32_ptr = x_fp32.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size;
uint8_t* fp8_ptr = fp8_tensor.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size;

uint8_t* scale_ue8m08sf_ptr = scale_tensor.data_ptr<uint8_t>();

float local_amax = 0.0f;
for (int ki = 0; ki < sf_vec_size; ++ki) {
local_amax = std::max(std::abs(fp32_ptr[ki]), local_amax);
}

local_amax *= (1.f / 448.0f);

uint8_t scale_ue8m0 = float_to_ue8m0(local_amax);
auto const inv_scale = (scale_ue8m0 == 0) ? 1 : exp2f(127 - static_cast<float>(scale_ue8m0));

scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0], groups_per_hidden_dim, layout)] =
scale_ue8m0;

for (int ki = 0; ki < sf_vec_size; ++ki) {
float const scaled_fp32_value = fp32_ptr[ki] * inv_scale;
auto fp8_value = cutlass::float_e4m3_t{scaled_fp32_value};
fp8_ptr[ki] = *reinterpret_cast<uint8_t*>(&fp8_value);
}
}
}
return std::make_tuple(fp8_tensor, scale_tensor);
}

// Used in tests to dequantize mxe4m3 tensors on host.
at::Tensor mxfp8_dequantize_host(at::Tensor value_e4m3, at::Tensor scale_ue8m08sf,
bool is_sf_swizzled_layout) {
int32_t const sf_vec_size = 32;
CHECK_CPU_INPUT(value_e4m3, c10::ScalarType::Byte);
CHECK_CPU_INPUT(scale_ue8m08sf, SF_DTYPE);
auto data_shape = value_e4m3.sizes();
auto scale_shape = scale_ue8m08sf.sizes();
TORCH_CHECK(data_shape.size() == 2, "value_e4m3 should be 2D tensor.");
TORCH_CHECK(scale_shape.size() == 1, "scale_ue8m08sf should be 1D tensor.");
at::Tensor float_tensor =
at::detail::empty_cpu({data_shape[0], data_shape[1]}, at::ScalarType::Float,
/* pinned */ true, at::MemoryFormat::Contiguous);

int hidden_dim = data_shape[1];
int groups_per_hidden_dim = hidden_dim / sf_vec_size;

tensorrt_llm::FP4QuantizationSFLayout layout =
is_sf_swizzled_layout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
: tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) {
for (int group = 0; group < groups_per_hidden_dim; ++group) {
float* float_ptr = float_tensor.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size;
uint8_t* fp8_ptr = value_e4m3.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size;
uint8_t* scale_ue8m08sf_ptr = scale_ue8m08sf.data_ptr<uint8_t>();
uint8_t fp8_scale = scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0],
groups_per_hidden_dim, layout)];

float scale_float;
uint32_t scale_float_u32 = uint32_t(fp8_scale) << 23;
memcpy(&scale_float, &scale_float_u32, sizeof(scale_float));

for (int ki = 0; ki < sf_vec_size; ++ki) {
uint8_t fp8_u8_repr = fp8_ptr[ki];
auto fp32 = static_cast<float>(*reinterpret_cast<cutlass::float_e4m3_t*>(&fp8_u8_repr));
float value = fp32 * scale_float;
float_ptr[ki] = value;
}
}
}
return float_tensor;
}
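
// Hypothetical round-trip sketch (not part of this PR; assumes the ATen
// factory functions are available through the existing includes): values
// should survive mxfp8_quantize_host -> mxfp8_dequantize_host up to e4m3
// mantissa rounding within each 32-wide group, since the shared UE8M0
// exponent is rounded up and never overflows.
inline void mxfp8_host_round_trip_sketch() {
  at::Tensor x = at::rand({4, 64}, at::kFloat);
  auto [q, sf] = mxfp8_quantize_host(x, /*is_sf_swizzled_layout=*/false);
  at::Tensor y = mxfp8_dequantize_host(q, sf, /*is_sf_swizzled_layout=*/false);
  // y approximately equals x element-wise; equality is exact only when the
  // scaled value is representable in e4m3.
}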
} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
m.def("mxfp8_dequantize_host", &torch_ext::mxfp8_dequantize_host);
m.def("mxfp8_quantize_host", &torch_ext::mxfp8_quantize_host);
m.def("mxfp8_quantize", &torch_ext::mxfp8_quantize);
}