
Commit b044dff

IwakuraRein authored and joker-eph committed
Add Blackwell MoE mxfp4 implementation from TRTLLM and Attention Sink support
1 parent 9158fef commit b044dff


47 files changed: +25240 −6535 lines

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 44 additions & 0 deletions
@@ -70,6 +70,41 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+////////////////////////////////////////////////////////////////////////////////////////////////////
+// MXFP8 Quantization
+
+template <typename T>
+void invokeMxFP8Quantization(int b, int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
+                             FP4QuantizationSFLayout layout, int multiProcessorCount,
+                             cudaStream_t stream) {
+  // Fixed SF_VEC_SIZE as 32
+  static constexpr int SF_VEC_SIZE = 32;
+
+  // Grid, Block size.
+  // Each thread converts 8 values.
+  dim3 block(std::min(int(n / CVT_FP4_ELTS_PER_THREAD), 512));
+  // Get number of blocks per SM (assume we can fully utilize the SM).
+  int const numBlocksPerSM = std::max(1u, 2048u / block.x);
+  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
+  // Launch the cvt kernel.
+  cudaLaunchConfig_t config;
+  config.gridDim = grid;
+  config.blockDim = block;
+  config.dynamicSmemBytes = 0;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(
+      &config,
+      quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
+      m, n, input, nullptr, reinterpret_cast<uint32_t*>(output),
+      reinterpret_cast<uint32_t*>(SFOuput), layout);
+}
+
 // Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
 template <typename T, typename QuantT>
 void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows,

@@ -320,6 +355,9 @@ template void invokeBatchedFP4Quantization<half, 16>(
 template void invokeBatchedFP4Quantization<half, 32>(
     int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput,
     bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream);
+template void invokeMxFP8Quantization<half>(int b, int m, int n, half const* input, int64_t* output,
+                                            int32_t* SFOuput, FP4QuantizationSFLayout layout,
+                                            int multiProcessorCount, cudaStream_t stream);
 #ifdef ENABLE_BF16
 template void invokeFP4Quantization<__nv_bfloat16, 16>(int m, int n, __nv_bfloat16 const* input,
                                                        float const* SFScale, int64_t* output,

@@ -341,6 +379,12 @@ template void invokeBatchedFP4Quantization<__nv_bfloat16, 32>(
     int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
     int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
     cudaStream_t stream);
+template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n,
+                                                     __nv_bfloat16 const* input, int64_t* output,
+                                                     int32_t* SFOuput,
+                                                     FP4QuantizationSFLayout layout,
+                                                     int multiProcessorCount, cudaStream_t stream);
+
 #endif

 #ifdef ENABLE_FP8
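The launcher above sizes the block from the column count and the grid from the row count. A minimal host-only sketch of that arithmetic (illustrative, not part of the commit; it assumes CVT_FP4_ELTS_PER_THREAD is 8 and up to 2048 resident threads per SM, as the comments in the launcher state):

```cpp
// Hypothetical, standalone sketch: reproduces the launch-geometry arithmetic used by
// invokeMxFP8Quantization. All values are illustrative assumptions.
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int kEltsPerThread = 8;     // assumption: CVT_FP4_ELTS_PER_THREAD
  int const multiProcessorCount = 132;  // hypothetical SM count
  int const m = 4096, n = 7168;

  // Each thread converts 8 values; cap the block at 512 threads.
  int const blockX = std::min(n / kEltsPerThread, 512);
  // Assume up to 2048 resident threads per SM.
  int const numBlocksPerSM = std::max(1, 2048 / blockX);
  int const gridX = std::min(m, multiProcessorCount * numBlocksPerSM);

  // For n = 7168: blockX = 512, numBlocksPerSM = 4, gridX = min(4096, 528) = 528.
  std::printf("block = %d, blocksPerSM = %d, grid = %d\n", blockX, numBlocksPerSM, gridX);
  return 0;
}
```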

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Lines changed: 85 additions & 0 deletions
@@ -917,6 +917,91 @@ cvt_fp8_to_fp4(
 #endif
 }

+template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(512, 4) quantize_with_block_size(
+#else
+quantize_with_block_size(
+#endif
+    int32_t numbatches, int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
+    uint32_t* out, uint32_t* SFout, FP4QuantizationSFLayout layout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+
+  // The elements per thread.
+  static constexpr int ELTS_PER_THREAD = quantization_type == BlockScaleQuantizationType::FP8_TO_FP4
+                                             ? CVT_FP8_TO_FP4_ELTS_PER_THREAD
+                                             : CVT_FP4_ELTS_PER_THREAD;
+
+  using PackedVec = PackedVec<Type>;
+  static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD;  // 2 or 4
+  static_assert(sizeof(PackedVec) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched.");
+
+  // Get the global scaling factor, which will be applied to the SF.
+  // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)).
+  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+
+  int numPaddedRows = numRows;
+  int numPaddedCols = numCols;
+  if (layout == FP4QuantizationSFLayout::SWIZZLED_128x4) {
+    // The number of padded rows considering 128x4 SF layout.
+    numPaddedRows = PadUpFn(numRows, 128);
+    numPaddedCols = PadUpFn(numCols, 4 * SF_VEC_SIZE);
+  } else if (layout == FP4QuantizationSFLayout::SWIZZLED_8x4) {
+    // The number of padded rows considering 8x4 SF layout.
+    numPaddedRows = PadUpFn(numRows, 8);
+    numPaddedCols = PadUpFn(numCols, 4 * SF_VEC_SIZE);
+  }
+
+  // The number of threads in the column dimension
+  int numColThreads = numCols / ELTS_PER_THREAD;
+  int numPaddedColThreads = numPaddedCols / ELTS_PER_THREAD;
+
+  asm volatile("griddepcontrol.wait;");
+  // Input tensor batch/row/col loops.
+  for (int rowIdx = blockIdx.x; rowIdx < numPaddedRows; rowIdx += gridDim.x) {
+    for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+      for (int colIdx = threadIdx.x; colIdx < numPaddedColThreads; colIdx += blockDim.x) {
+        std::optional<int> optionalBatchIdx = batchIdx;
+        std::optional<int> optionalNumRows = numRows;
+
+        // The SF output pointer.
+        auto sf_out =
+            cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF, SF_VEC_SIZE>(
+                optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numCols, SFout, layout);
+
+        // Set the SF padding to 0.
+        if (rowIdx >= numRows || colIdx >= numColThreads) {
+          if (sf_out != nullptr) {
+            sf_out[0] = 0x00;
+          }
+        } else {
+          int64_t inOffset =
+              static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
+          PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+          // Get the output tensor offset as a packed vector.
+          int64_t outOffset = inOffset;
+
+          // Dispatch the quantization kernel.
+          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+            reinterpret_cast<uint32_t*>(out)[outOffset] =
+                cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
+            reinterpret_cast<uint64_t*>(out)[outOffset] =
+                cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
+                                                                          sf_out);
+          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+            reinterpret_cast<uint64_t*>(out)[outOffset] =
+                cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+          }
+        }
+      }
+    }
+  }
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
 __global__ void nvfp4_block_scale_interleave_kernel(int numbatches, int numRows, int numCols,
                                                     uint8_t const* SFIn, uint8_t* SFOutput);
 }  // namespace kernels
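To make the padding logic in this kernel concrete, here is a small host-side sketch (illustrative, not part of the commit) of the padded row/column counts for the SWIZZLED_128x4 layout with SF_VEC_SIZE = 32; the helper mirrors PadUpFn rather than calling it:

```cpp
// Illustrative only: mirrors PadUpFn and the SWIZZLED_128x4 padding used by
// quantize_with_block_size. Input sizes are hypothetical.
#include <cstdio>

static int padUp(int x, int y) { return (x + y - 1) / y * y; }

int main() {
  constexpr int kSfVecSize = 32;
  int const numRows = 300, numCols = 2560;

  // SWIZZLED_128x4: rows padded to a multiple of 128, columns to a multiple of 4 * SF_VEC_SIZE.
  int const numPaddedRows = padUp(numRows, 128);             // 384
  int const numPaddedCols = padUp(numCols, 4 * kSfVecSize);  // 2560 (already a multiple of 128)
  std::printf("padded rows = %d, padded cols = %d\n", numPaddedRows, numPaddedCols);
  return 0;
}
```

Rows and columns beyond the real extent only zero out their scale-factor slots, so the padded region costs a few extra iterations but no extra tensor traffic.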

csrc/nv_internal/tensorrt_llm/kernels/quantization.h

Lines changed: 12 additions & 0 deletions
@@ -41,6 +41,13 @@ enum class FP4QuantizationSFLayout {
   LINEAR
 };

+// This denotes the input and output data types of the block scale quantization.
+enum class BlockScaleQuantizationType {
+  FP16_TO_FP4 = 0,
+  FP8_TO_FP4 = 1,
+  FP16_TO_MXFP8 = 2,
+};
+
 #define PadUpFn(X, Y) ((X + Y - 1) / (Y) * (Y))

 // totalCloumn should be in SFMatrix, not activation Matrix, so no sfVecSize needed.

@@ -86,5 +93,10 @@ void invokeNVFP4BlockScaleInterleaveReverse(int b, int m, int n, uint8_t const*
                                             uint8_t* SFOutput, int multiProcessorCount,
                                             cudaStream_t stream = 0);

+template <typename T>
+void invokeMxFP8Quantization(int b, int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
+                             FP4QuantizationSFLayout layout, int multiProcessorCount,
+                             cudaStream_t stream = 0);
+
 } // namespace kernels
 } // namespace tensorrt_llm
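A hedged sketch of how a caller might drive the entry point declared here (illustrative only, not part of the commit; the include path, SM-count source, and buffer-sizing helper are assumptions, and the namespace qualifications follow the usage in the thop code further below):

```cpp
// Hypothetical caller sketch. Assumes the declarations above are reachable via this include
// path and that a CUDA context exists. Error checks are omitted for brevity.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "tensorrt_llm/kernels/quantization.h"

void quantizeActivationsToMxFP8(half const* d_input, int m, int k, int smCount,
                                cudaStream_t stream) {
  // One FP8 byte per input element.
  void* d_fp8 = nullptr;
  cudaMalloc(&d_fp8, static_cast<size_t>(m) * k);

  // One UE8M0 scale byte per 32-element block, padded to the 128x4 swizzled layout:
  // ceil(M / 128) * 128 rows times ceil(K / 32 / 4) * 4 groups.
  auto padUp = [](long long x, long long y) { return (x + y - 1) / y * y; };
  size_t const sfBytes = static_cast<size_t>(padUp(m, 128) * padUp(k / 32, 4));
  void* d_sf = nullptr;
  cudaMalloc(&d_sf, sfBytes);

  tensorrt_llm::kernels::invokeMxFP8Quantization<half>(
      /*b=*/1, m, k, d_input, reinterpret_cast<int64_t*>(d_fp8),
      reinterpret_cast<int32_t*>(d_sf),
      tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4, smCount, stream);
}
```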
Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/thop/fp8Quantize.h"

#include <ATen/cuda/EmptyTensor.h>

#include "cutlass/numeric_types.h"
#include "pytorch_extension_utils.h"
#include "tensorrt_llm/thop/thUtils.h"

namespace torch_ext {

// self: [M, K], fp16/bf16/fp8_quantized
// isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in
// linear layout. See FP4QuantizationSFLayout enum for more details about the two layouts.
// returns
// self_fp8, self_block_scale_factors
// self_fp8: [M, K], uint8_t
// self_block_scale_factors: ceil(M / 128) * 128 * ceil(K / 32 / 4) * 4, uint8_t
std::tuple<at::Tensor, at::Tensor> mxfp8_quantize(at::Tensor self, bool isSfSwizzledLayout) {
  CHECK_TH_CUDA(self);
  CHECK_CONTIGUOUS(self);

  auto const& inputShape = self.sizes();
  auto const& rank = inputShape.size();

  TORCH_CHECK(rank >= 2, "Input should be >=2D tensor.");
  int64_t m = 1;
  for (size_t i = 0; i < rank - 1; i++) {
    m *= inputShape[i];
  }
  auto const k = inputShape[rank - 1];
  int32_t const sfVecSize = 32;
  TORCH_CHECK(k % sfVecSize == 0);

  std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
  outputShape[rank - 1] = k;

  at::Tensor valueFP8 =
      at::detail::empty_cuda(outputShape, at::ScalarType::Float8_e4m3fn, self.device(),
                             /* stride */ std::nullopt);

  int64_t SFSize = isSfSwizzledLayout
                       ? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)
                       : tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize);

  at::Tensor scaleFP8SF = at::detail::empty_cuda({SFSize}, SF_DTYPE, self.device(),
                                                 /* stride */ std::nullopt);  // 1D tensor

  const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

  auto const layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
                                         : tensorrt_llm::FP4QuantizationSFLayout::LINEAR;

#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T)                                                \
  tensorrt_llm::kernels::invokeMxFP8Quantization<T>(                                   \
      1, m, k, reinterpret_cast<T*>(self.data_ptr()),                                  \
      reinterpret_cast<int64_t*>(valueFP8.data_ptr()),                                 \
      reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), layout, mMultiProcessorCount, \
      at::cuda::getCurrentCUDAStream(self.get_device()));

  if (self.scalar_type() == at::ScalarType::Half) {
    LAUNCH_MXFP8_QUANTIZE_KERNEL(half)
  } else if (self.scalar_type() == at::ScalarType::BFloat16) {
#ifdef ENABLE_BF16
    LAUNCH_MXFP8_QUANTIZE_KERNEL(__nv_bfloat16)
#else
    C10_THROW_ERROR(NotImplementedError,
                    "BFloat16 must be enabled to quantize an bf16 tensor to mxfp8.");
#endif
  } else {
    C10_THROW_ERROR(NotImplementedError,
                    "mxfp8_quantize only supports input tensor with dtypes fp16/bf16.");
  }

#undef LAUNCH_MXFP8_QUANTIZE_KERNEL

  return {valueFP8, scaleFP8SF};
}

inline uint8_t float_to_ue8m0(float value) {
  if (value == 0.0f) {
    return 0x00;
  }
  constexpr uint32_t FP32_MANTISSA_BITS = 23;
  uint32_t val_u32 = *reinterpret_cast<uint32_t*>(&value);
  uint8_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
  uint32_t mantissa = val_u32 & 0x7FFFFF;
  // Round up exponent and deal with satfinite.
  if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
    ++exponent;
  }
  return exponent;
}

// Used in tests to quantize mxe4m3 tensors on host.
std::tuple<at::Tensor, at::Tensor> mxfp8_quantize_host(at::Tensor x_fp32,
                                                       bool is_sf_swizzled_layout) {
  int32_t const sf_vec_size = 32;
  CHECK_CPU_INPUT(x_fp32, c10::ScalarType::Float);
  auto data_shape = x_fp32.sizes();
  TORCH_CHECK(data_shape.size() == 2, "x_fp32 should be 2D tensor.");
  int num_tokens = data_shape[0];
  int hidden_dim = data_shape[1];
  int groups_per_hidden_dim = hidden_dim / sf_vec_size;

  at::Tensor fp8_tensor = at::detail::empty_cpu({num_tokens, hidden_dim}, at::ScalarType::Byte,
                                                /* pinned */ true, at::MemoryFormat::Contiguous);
  int64_t sf_size =
      is_sf_swizzled_layout
          ? tensorrt_llm::computeFP4SwizzledLayoutSFSize(num_tokens, hidden_dim / sf_vec_size)
          : tensorrt_llm::computeFP4LinearLayoutSFSize(num_tokens, hidden_dim / sf_vec_size);
  at::Tensor scale_tensor =
      at::detail::empty_cpu({sf_size}, SF_DTYPE, /* pinned */ true, at::MemoryFormat::Contiguous);

  tensorrt_llm::FP4QuantizationSFLayout layout =
      is_sf_swizzled_layout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
                            : tensorrt_llm::FP4QuantizationSFLayout::LINEAR;

  for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) {
    for (int group = 0; group < groups_per_hidden_dim; ++group) {
      float* fp32_ptr = x_fp32.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size;
      uint8_t* fp8_ptr = fp8_tensor.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size;

      uint8_t* scale_ue8m08sf_ptr = scale_tensor.data_ptr<uint8_t>();

      float local_amax = 0.0f;
      for (int ki = 0; ki < sf_vec_size; ++ki) {
        local_amax = std::max(std::abs(fp32_ptr[ki]), local_amax);
      }

      local_amax *= (1.f / 448.0f);

      uint8_t scale_ue8m0 = float_to_ue8m0(local_amax);
      auto const inv_scale = (scale_ue8m0 == 0) ? 1 : exp2f(127 - static_cast<float>(scale_ue8m0));

      scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0], groups_per_hidden_dim, layout)] =
          scale_ue8m0;

      for (int ki = 0; ki < sf_vec_size; ++ki) {
        float const scaled_fp32_value = fp32_ptr[ki] * inv_scale;
        auto fp8_value = cutlass::float_e4m3_t{scaled_fp32_value};
        fp8_ptr[ki] = *reinterpret_cast<uint8_t*>(&fp8_value);
      }
    }
  }
  return std::make_tuple(fp8_tensor, scale_tensor);
}

// Used in tests to dequantize mxe4m3 tensors on host.
at::Tensor mxfp8_dequantize_host(at::Tensor value_e4m3, at::Tensor scale_ue8m08sf,
                                 bool is_sf_swizzled_layout) {
  int32_t const sf_vec_size = 32;
  CHECK_CPU_INPUT(value_e4m3, c10::ScalarType::Byte);
  CHECK_CPU_INPUT(scale_ue8m08sf, SF_DTYPE);
  auto data_shape = value_e4m3.sizes();
  auto scale_shape = scale_ue8m08sf.sizes();
  TORCH_CHECK(data_shape.size() == 2, "value_e4m3 should be 2D tensor.");
  TORCH_CHECK(scale_shape.size() == 1, "scale_ue8m08sf should be 1D tensor.");
  at::Tensor float_tensor =
      at::detail::empty_cpu({data_shape[0], data_shape[1]}, at::ScalarType::Float,
                            /* pinned */ true, at::MemoryFormat::Contiguous);

  int hidden_dim = data_shape[1];
  int groups_per_hidden_dim = hidden_dim / sf_vec_size;

  tensorrt_llm::FP4QuantizationSFLayout layout =
      is_sf_swizzled_layout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
                            : tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
  for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) {
    for (int group = 0; group < groups_per_hidden_dim; ++group) {
      float* float_ptr = float_tensor.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size;
      uint8_t* fp8_ptr = value_e4m3.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size;
      uint8_t* scale_ue8m08sf_ptr = scale_ue8m08sf.data_ptr<uint8_t>();
      uint8_t fp8_scale = scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0],
                                                            groups_per_hidden_dim, layout)];

      float scale_float;
      uint32_t scale_float_u32 = uint32_t(fp8_scale) << 23;
      memcpy(&scale_float, &scale_float_u32, sizeof(scale_float));

      for (int ki = 0; ki < sf_vec_size; ++ki) {
        uint8_t fp8_u8_repr = fp8_ptr[ki];
        auto fp32 = static_cast<float>(*reinterpret_cast<cutlass::float_e4m3_t*>(&fp8_u8_repr));
        float value = fp32 * scale_float;
        float_ptr[ki] = value;
      }
    }
  }
  return float_tensor;
}
}  // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("mxfp8_dequantize_host", &torch_ext::mxfp8_dequantize_host);
  m.def("mxfp8_quantize_host", &torch_ext::mxfp8_quantize_host);
  m.def("mxfp8_quantize", &torch_ext::mxfp8_quantize);
}
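The scale-factor buffer size quoted in the comment above, ceil(M / 128) * 128 * ceil(K / 32 / 4) * 4 bytes, and the UE8M0 rounding in float_to_ue8m0 can both be checked with a small host-only sketch (illustrative, not part of the commit; it mirrors the arithmetic rather than calling the commit's helpers):

```cpp
// Standalone, illustrative sketch: reproduces the swizzled SF-buffer sizing and the
// UE8M0 (power-of-two) block scale used for 32-element MXFP8 blocks.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // SF buffer size for M x K activations, 32 elements per scale, 128x4 swizzled layout.
  int64_t const M = 300, K = 4096;
  int64_t const paddedRows = (M + 127) / 128 * 128;    // 384
  int64_t const paddedGroups = (K / 32 + 3) / 4 * 4;   // 128
  std::printf("SF bytes = %lld\n", static_cast<long long>(paddedRows * paddedGroups));  // 49152

  // UE8M0 scale for one block: amax / 448 rounded up to the next power of two,
  // stored as a biased FP32 exponent (same effect as float_to_ue8m0 above).
  float const blockAmax = 100.0f;
  float const target = blockAmax / 448.0f;                              // ~0.2232
  int const exponent = static_cast<int>(std::ceil(std::log2(target)));  // -2
  uint8_t const ue8m0 = static_cast<uint8_t>(exponent + 127);           // 125
  float const scale = std::ldexp(1.0f, exponent);                       // 0.25
  std::printf("ue8m0 = %d, scale = %g, inv_scale = %g\n", static_cast<int>(ue8m0), scale,
              1.0f / scale);
  return 0;
}
```

Dequantization then reverses the step exactly as mxfp8_dequantize_host does: shift the stored exponent into the FP32 exponent field (ue8m0 << 23) and multiply each decoded e4m3 value by the resulting power of two.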
