
Commit 65e9092

GPT-OSS Support: Add Blackwell MoE mxfp4 implementation from TRTLLM and Attention Sink (#1389)
These kernels support [OpenAI GPT-OSS](https://openai.com/index/introducing-gpt-oss/).

Co-authored-by: siyuanf <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
Co-authored-by: Qidi Sang <[email protected]>
1 parent 9158fef commit 65e9092


47 files changed: +25,275 −6,562 lines

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 44 additions & 0 deletions
@@ -70,6 +70,41 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// MXFP8 Quantization
+
+template <typename T>
+void invokeMxFP8Quantization(int b, int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
+                             FP4QuantizationSFLayout layout, int multiProcessorCount,
+                             cudaStream_t stream) {
+  // Fixed SF_VEC_SIZE as 32
+  static constexpr int SF_VEC_SIZE = 32;
+
+  // Grid, Block size.
+  // Each thread converts 8 values.
+  dim3 block(std::min(int(n / CVT_FP4_ELTS_PER_THREAD), 512));
+  // Get number of blocks per SM (assume we can fully utilize the SM).
+  int const numBlocksPerSM = std::max(1u, 2048u / block.x);
+  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
+  // Launch the cvt kernel.
+  cudaLaunchConfig_t config;
+  config.gridDim = grid;
+  config.blockDim = block;
+  config.dynamicSmemBytes = 0;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(
+      &config,
+      quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
+      m, n, input, nullptr, reinterpret_cast<uint32_t*>(output),
+      reinterpret_cast<uint32_t*>(SFOuput), layout);
+}
+
 // Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
 template <typename T, typename QuantT>
 void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows,
@@ -320,6 +355,9 @@ template void invokeBatchedFP4Quantization<half, 16>(
 template void invokeBatchedFP4Quantization<half, 32>(
     int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput,
     bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream);
+template void invokeMxFP8Quantization<half>(int b, int m, int n, half const* input, int64_t* output,
+                                            int32_t* SFOuput, FP4QuantizationSFLayout layout,
+                                            int multiProcessorCount, cudaStream_t stream);
 #ifdef ENABLE_BF16
 template void invokeFP4Quantization<__nv_bfloat16, 16>(int m, int n, __nv_bfloat16 const* input,
                                                        float const* SFScale, int64_t* output,
@@ -341,6 +379,12 @@ template void invokeBatchedFP4Quantization<__nv_bfloat16, 32>(
     int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
     int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
     cudaStream_t stream);
+template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n,
+                                                     __nv_bfloat16 const* input, int64_t* output,
+                                                     int32_t* SFOuput,
+                                                     FP4QuantizationSFLayout layout,
+                                                     int multiProcessorCount, cudaStream_t stream);
+
 #endif
 
 #ifdef ENABLE_FP8
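
The grid/block heuristic in `invokeMxFP8Quantization` above is plain host arithmetic: each thread converts 8 input values, blocks are capped at 512 threads, and the grid is capped at `multiProcessorCount * max(1, 2048 / blockDim.x)` so roughly 2048 threads can be resident per SM. Below is a minimal standalone sketch of that sizing logic with an illustrative shape and SM count; the `LaunchDims`/`pick_launch_dims` names are not part of the commit.

#include <algorithm>
#include <cstdio>

struct LaunchDims {
  int block;  // threads per block
  int grid;   // thread blocks
};

// Mirrors the grid/block heuristic of invokeMxFP8Quantization (illustrative helper).
LaunchDims pick_launch_dims(int m, int n, int multiProcessorCount) {
  constexpr int kEltsPerThread = 8;                 // each thread converts 8 values
  int block = std::min(n / kEltsPerThread, 512);    // cap at 512 threads per block
  int numBlocksPerSM = std::max(1, 2048 / block);   // aim for ~2048 resident threads per SM
  int grid = std::min(m, multiProcessorCount * numBlocksPerSM);
  return {block, grid};
}

int main() {
  // e.g. an [4096, 2880] activation on a GPU with 132 SMs (numbers are illustrative)
  LaunchDims d = pick_launch_dims(/*m=*/4096, /*n=*/2880, /*multiProcessorCount=*/132);
  std::printf("block=%d grid=%d\n", d.block, d.grid);  // block=360 grid=660
}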

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Lines changed: 85 additions & 0 deletions
@@ -917,6 +917,91 @@ cvt_fp8_to_fp4(
 #endif
 }
 
+template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(512, 4) quantize_with_block_size(
+#else
+quantize_with_block_size(
+#endif
+    int32_t numbatches, int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
+    uint32_t* out, uint32_t* SFout, FP4QuantizationSFLayout layout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+
+  // The elements per thread.
+  static constexpr int ELTS_PER_THREAD = quantization_type == BlockScaleQuantizationType::FP8_TO_FP4
+                                             ? CVT_FP8_TO_FP4_ELTS_PER_THREAD
+                                             : CVT_FP4_ELTS_PER_THREAD;
+
+  using PackedVec = PackedVec<Type>;
+  static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD;  // 2 or 4
+  static_assert(sizeof(PackedVec) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched.");
+
+  // Get the global scaling factor, which will be applied to the SF.
+  // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)).
+  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+
+  int numPaddedRows = numRows;
+  int numPaddedCols = numCols;
+  if (layout == FP4QuantizationSFLayout::SWIZZLED_128x4) {
+    // The number of padded rows considering 128x4 SF layout.
+    numPaddedRows = PadUpFn(numRows, 128);
+    numPaddedCols = PadUpFn(numCols, 4 * SF_VEC_SIZE);
+  } else if (layout == FP4QuantizationSFLayout::SWIZZLED_8x4) {
+    // The number of padded rows considering 8x4 SF layout.
+    numPaddedRows = PadUpFn(numRows, 8);
+    numPaddedCols = PadUpFn(numCols, 4 * SF_VEC_SIZE);
+  }
+
+  // The number of threads in the column dimension.
+  int numColThreads = numCols / ELTS_PER_THREAD;
+  int numPaddedColThreads = numPaddedCols / ELTS_PER_THREAD;
+
+  asm volatile("griddepcontrol.wait;");
+  // Input tensor batch/row/col loops.
+  for (int rowIdx = blockIdx.x; rowIdx < numPaddedRows; rowIdx += gridDim.x) {
+    for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+      for (int colIdx = threadIdx.x; colIdx < numPaddedColThreads; colIdx += blockDim.x) {
+        std::optional<int> optionalBatchIdx = batchIdx;
+        std::optional<int> optionalNumRows = numRows;
+
+        // The SF output pointer.
+        auto sf_out =
+            cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF, SF_VEC_SIZE>(
+                optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numCols, SFout, layout);
+
+        // Set the SF padding to 0.
+        if (rowIdx >= numRows || colIdx >= numColThreads) {
+          if (sf_out != nullptr) {
+            sf_out[0] = 0x00;
+          }
+        } else {
+          int64_t inOffset =
+              static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
+          PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+          // Get the output tensor offset as a packed vector.
+          int64_t outOffset = inOffset;
+
+          // Dispatch the quantization kernel.
+          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+            reinterpret_cast<uint32_t*>(out)[outOffset] =
+                cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
+            reinterpret_cast<uint64_t*>(out)[outOffset] =
+                cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
+                                                                          sf_out);
+          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+            reinterpret_cast<uint64_t*>(out)[outOffset] =
+                cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+          }
+        }
+      }
+    }
+  }
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
 __global__ void nvfp4_block_scale_interleave_kernel(int numbatches, int numRows, int numCols,
                                                     uint8_t const* SFIn, uint8_t* SFOutput);
 }  // namespace kernels
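
For the swizzled 128x4 scale-factor layout, `quantize_with_block_size` rounds the row count up to a multiple of 128 and the column count up to a multiple of `4 * SF_VEC_SIZE`; threads that land in the padded region only zero out their scale factor. A small host-side sketch of that padding arithmetic, reusing the `PadUpFn` macro from quantization.h (the example shape is arbitrary):

#include <cstdio>

// Same rounding macro as in quantization.h.
#define PadUpFn(X, Y) ((X + Y - 1) / (Y) * (Y))

int main() {
  constexpr int SF_VEC_SIZE = 32;      // MXFP8 block size: one UE8M0 scale per 32 elements
  int numRows = 300, numCols = 2880;   // example activation shape

  // SWIZZLED_128x4: rows padded to a multiple of 128, columns to a multiple of 4 * SF_VEC_SIZE.
  int numPaddedRows = PadUpFn(numRows, 128);
  int numPaddedCols = PadUpFn(numCols, 4 * SF_VEC_SIZE);

  std::printf("padded rows: %d -> %d\n", numRows, numPaddedRows);  // 300 -> 384
  std::printf("padded cols: %d -> %d\n", numCols, numPaddedCols);  // 2880 -> 2944
  std::printf("scale factors per padded row: %d\n", numPaddedCols / SF_VEC_SIZE);  // 92
}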

csrc/nv_internal/tensorrt_llm/kernels/quantization.h

Lines changed: 12 additions & 0 deletions
@@ -41,6 +41,13 @@ enum class FP4QuantizationSFLayout {
   LINEAR
 };
 
+// This denotes the input and output data types of the block scale quantization.
+enum class BlockScaleQuantizationType {
+  FP16_TO_FP4 = 0,
+  FP8_TO_FP4 = 1,
+  FP16_TO_MXFP8 = 2,
+};
+
 #define PadUpFn(X, Y) ((X + Y - 1) / (Y) * (Y))
 
 // totalCloumn should be in SFMatrix, not activation Matrix, so no sfVecSize needed.
@@ -86,5 +93,10 @@ void invokeNVFP4BlockScaleInterleaveReverse(int b, int m, int n, uint8_t const*
                                              uint8_t* SFOutput, int multiProcessorCount,
                                              cudaStream_t stream = 0);
 
+template <typename T>
+void invokeMxFP8Quantization(int b, int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
+                             FP4QuantizationSFLayout layout, int multiProcessorCount,
+                             cudaStream_t stream = 0);
+
 }  // namespace kernels
 }  // namespace tensorrt_llm
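
The new `BlockScaleQuantizationType` enum selects both the conversion routine and the packing that `quantize_with_block_size` writes per thread: FP16_TO_FP4 stores a packed `uint32_t` of fp4 values, while FP8_TO_FP4 and FP16_TO_MXFP8 store a `uint64_t`; on the MXFP8 path one UE8M0 scale byte covers a fixed 32-element block. A rough host-side sketch of the MXFP8 storage footprint implied by the declaration above, assuming the LINEAR scale layout and ignoring any padding the swizzled layouts add (`MxFp8Footprint` and `mxfp8_linear_footprint` are illustrative names, not APIs from this commit):

#include <cstdint>
#include <cstdio>

// Storage implied by FP16_TO_MXFP8 with a LINEAR scale-factor layout:
// 1 byte of e4m3 per element plus 1 UE8M0 scale byte per 32-element block.
struct MxFp8Footprint {
  int64_t dataBytes;
  int64_t scaleBytes;
};

MxFp8Footprint mxfp8_linear_footprint(int64_t m, int64_t k) {
  constexpr int64_t kSfVecSize = 32;  // fixed block size for the MXFP8 path in this commit
  return {m * k, m * (k / kSfVecSize)};
}

int main() {
  auto f = mxfp8_linear_footprint(4096, 2880);
  std::printf("e4m3 bytes: %lld, UE8M0 scale bytes: %lld\n",
              (long long)f.dataBytes, (long long)f.scaleBytes);
  // e4m3 bytes: 11796480, UE8M0 scale bytes: 368640
}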
Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/thop/fp8Quantize.h"
+
+#include <ATen/cuda/EmptyTensor.h>
+
+#include "cutlass/numeric_types.h"
+#include "pytorch_extension_utils.h"
+#include "tensorrt_llm/thop/thUtils.h"
+
+namespace torch_ext {
+
+// input: [M, K], fp32/fp16/bf16/fp8_quantized
+// isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in
+// linear layout. See FP4QuantizationSFLayout enum for more details about the two layouts.
+// returns: a tuple of (quantized e4m3 tensor, UE8M0 scale factor tensor)
+std::tuple<at::Tensor, at::Tensor> mxfp8_quantize(at::Tensor input, bool isSfSwizzledLayout) {
+  CHECK_TH_CUDA(input);
+  CHECK_CONTIGUOUS(input);
+
+  auto const& inputShape = input.sizes();
+  auto const& rank = inputShape.size();
+
+  TORCH_CHECK(rank >= 2, "Input should be >=2D tensor.");
+  int64_t m = 1;
+  for (size_t i = 0; i < rank - 1; i++) {
+    m *= inputShape[i];
+  }
+  auto const k = inputShape[rank - 1];
+  int32_t const sfVecSize = 32;
+  TORCH_CHECK(k % sfVecSize == 0);
+
+  std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
+  outputShape[rank - 1] = k;
+
+  at::Tensor valueFP8 =
+      at::detail::empty_cuda(outputShape, at::ScalarType::Float8_e4m3fn, input.device(),
+                             /* stride */ std::nullopt);
+
+  int64_t SFSize = isSfSwizzledLayout
+                       ? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)
+                       : tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize);
+
+  at::Tensor scaleFP8SF = at::detail::empty_cuda({SFSize}, SF_DTYPE, input.device(),
+                                                 /* stride */ std::nullopt);  // 1D tensor
+
+  const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+
+  auto const layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
+                                         : tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
+
+#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T)                                                 \
+  tensorrt_llm::kernels::invokeMxFP8Quantization<T>(                                   \
+      1, m, k, reinterpret_cast<T*>(input.data_ptr()),                                 \
+      reinterpret_cast<int64_t*>(valueFP8.data_ptr()),                                 \
+      reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), layout, mMultiProcessorCount, \
+      at::cuda::getCurrentCUDAStream(input.get_device()));
+
+  if (input.scalar_type() == at::ScalarType::Half) {
+    LAUNCH_MXFP8_QUANTIZE_KERNEL(half)
+  } else if (input.scalar_type() == at::ScalarType::BFloat16) {
+#ifdef ENABLE_BF16
+    LAUNCH_MXFP8_QUANTIZE_KERNEL(__nv_bfloat16)
+#else
+    C10_THROW_ERROR(NotImplementedError,
+                    "BFloat16 must be enabled to quantize a bf16 tensor to mxfp8.");
+#endif
+  } else {
+    C10_THROW_ERROR(NotImplementedError,
+                    "mxfp8_quantize only supports input tensor with dtypes fp16/bf16.");
+  }
+
+#undef LAUNCH_MXFP8_QUANTIZE_KERNEL
+
+  return {valueFP8, scaleFP8SF};
+}
+
+inline uint8_t float_to_ue8m0(float value) {
+  if (value == 0.0f) {
+    return 0x00;
+  }
+  constexpr uint32_t FP32_MANTISSA_BITS = 23;
+  uint32_t val_u32 = *reinterpret_cast<uint32_t*>(&value);
+  uint8_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
+  uint32_t mantissa = val_u32 & 0x7FFFFF;
+  // Round up exponent and deal with satfinite.
+  if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
+    ++exponent;
+  }
+  return exponent;
+}
+
+// Used in tests to quantize mxe4m3 tensors on host.
+std::tuple<at::Tensor, at::Tensor> mxfp8_quantize_host(at::Tensor x_fp32,
+                                                       bool is_sf_swizzled_layout) {
+  int32_t const sf_vec_size = 32;
+  CHECK_CPU_INPUT(x_fp32, c10::ScalarType::Float);
+  auto data_shape = x_fp32.sizes();
+  TORCH_CHECK(data_shape.size() == 2, "x_fp32 should be 2D tensor.");
+  int num_tokens = data_shape[0];
+  int hidden_dim = data_shape[1];
+  int groups_per_hidden_dim = hidden_dim / sf_vec_size;
+
+  at::Tensor fp8_tensor = at::detail::empty_cpu({num_tokens, hidden_dim}, at::ScalarType::Byte,
+                                                /* pinned */ true, at::MemoryFormat::Contiguous);
+  int64_t sf_size =
+      is_sf_swizzled_layout
+          ? tensorrt_llm::computeFP4SwizzledLayoutSFSize(num_tokens, hidden_dim / sf_vec_size)
+          : tensorrt_llm::computeFP4LinearLayoutSFSize(num_tokens, hidden_dim / sf_vec_size);
+  at::Tensor scale_tensor =
+      at::detail::empty_cpu({sf_size}, SF_DTYPE, /* pinned */ true, at::MemoryFormat::Contiguous);
+
+  tensorrt_llm::FP4QuantizationSFLayout layout =
+      is_sf_swizzled_layout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
+                            : tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
+
+  for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) {
+    for (int group = 0; group < groups_per_hidden_dim; ++group) {
+      float* fp32_ptr = x_fp32.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size;
+      uint8_t* fp8_ptr = fp8_tensor.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size;
+
+      uint8_t* scale_ue8m08sf_ptr = scale_tensor.data_ptr<uint8_t>();
+
+      float local_amax = 0.0f;
+      for (int ki = 0; ki < sf_vec_size; ++ki) {
+        local_amax = std::max(std::abs(fp32_ptr[ki]), local_amax);
+      }
+
+      local_amax *= (1.f / 448.0f);
+
+      uint8_t scale_ue8m0 = float_to_ue8m0(local_amax);
+      auto const inv_scale = (scale_ue8m0 == 0) ? 1 : exp2f(127 - static_cast<float>(scale_ue8m0));
+
+      scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0], groups_per_hidden_dim, layout)] =
+          scale_ue8m0;
+
+      for (int ki = 0; ki < sf_vec_size; ++ki) {
+        float const scaled_fp32_value = fp32_ptr[ki] * inv_scale;
+        auto fp8_value = cutlass::float_e4m3_t{scaled_fp32_value};
+        fp8_ptr[ki] = *reinterpret_cast<uint8_t*>(&fp8_value);
+      }
+    }
+  }
+  return std::make_tuple(fp8_tensor, scale_tensor);
+}
+
+// Used in tests to dequantize mxe4m3 tensors on host.
+at::Tensor mxfp8_dequantize_host(at::Tensor value_e4m3, at::Tensor scale_ue8m08sf,
+                                 bool is_sf_swizzled_layout) {
+  int32_t const sf_vec_size = 32;
+  CHECK_CPU_INPUT(value_e4m3, c10::ScalarType::Byte);
+  CHECK_CPU_INPUT(scale_ue8m08sf, SF_DTYPE);
+  auto data_shape = value_e4m3.sizes();
+  auto scale_shape = scale_ue8m08sf.sizes();
+  TORCH_CHECK(data_shape.size() == 2, "value_e4m3 should be 2D tensor.");
+  TORCH_CHECK(scale_shape.size() == 1, "scale_ue8m08sf should be 1D tensor.");
+  at::Tensor float_tensor =
+      at::detail::empty_cpu({data_shape[0], data_shape[1]}, at::ScalarType::Float,
+                            /* pinned */ true, at::MemoryFormat::Contiguous);
+
+  int hidden_dim = data_shape[1];
+  int groups_per_hidden_dim = hidden_dim / sf_vec_size;
+
+  tensorrt_llm::FP4QuantizationSFLayout layout =
+      is_sf_swizzled_layout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
+                            : tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
+  for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) {
+    for (int group = 0; group < groups_per_hidden_dim; ++group) {
+      float* float_ptr = float_tensor.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size;
+      uint8_t* fp8_ptr = value_e4m3.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size;
+      uint8_t* scale_ue8m08sf_ptr = scale_ue8m08sf.data_ptr<uint8_t>();
+      uint8_t fp8_scale = scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0],
+                                                            groups_per_hidden_dim, layout)];
+
+      float scale_float;
+      uint32_t scale_float_u32 = uint32_t(fp8_scale) << 23;
+      memcpy(&scale_float, &scale_float_u32, sizeof(scale_float));
+
+      for (int ki = 0; ki < sf_vec_size; ++ki) {
+        uint8_t fp8_u8_repr = fp8_ptr[ki];
+        auto fp32 = static_cast<float>(*reinterpret_cast<cutlass::float_e4m3_t*>(&fp8_u8_repr));
+        float value = fp32 * scale_float;
+        float_ptr[ki] = value;
+      }
+    }
+  }
+  return float_tensor;
+}
+}  // namespace torch_ext
+
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  m.def("mxfp8_dequantize_host", &torch_ext::mxfp8_dequantize_host);
+  m.def("mxfp8_quantize_host", &torch_ext::mxfp8_quantize_host);
+  m.def("mxfp8_quantize", &torch_ext::mxfp8_quantize);
+}
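
`float_to_ue8m0` above keeps only a (rounded-up) FP32 exponent byte, and `mxfp8_dequantize_host` rebuilds the scale by shifting that byte back into the exponent field, so a stored byte e decodes to the power of two 2^(e-127); the host quantizer applies it to each block's amax / 448. A self-contained sketch of that encode/decode pair using the same rounding rule (the `ue8m0_encode`/`ue8m0_decode` names are illustrative; `memcpy` replaces the pointer cast to avoid aliasing issues):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Encode a non-negative float as a UE8M0 byte: keep the FP32 exponent, rounding up
// when mantissa bits would otherwise be dropped (same rule as float_to_ue8m0).
uint8_t ue8m0_encode(float value) {
  if (value == 0.0f) return 0x00;
  uint32_t u;
  std::memcpy(&u, &value, sizeof(u));
  uint8_t exponent = static_cast<uint8_t>(u >> 23);
  uint32_t mantissa = u & 0x7FFFFF;
  if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
    ++exponent;
  }
  return exponent;
}

// Decode a UE8M0 byte back to a float scale, as mxfp8_dequantize_host does:
// place the byte in the FP32 exponent field, so e maps to 2^(e - 127).
float ue8m0_decode(uint8_t e) {
  uint32_t u = uint32_t(e) << 23;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  // Per-block scale: amax / 448, rounded up to a power of two by the encoder.
  float amax = 13.5f;
  uint8_t e = ue8m0_encode(amax / 448.0f);
  std::printf("encoded exponent: %u, decoded scale: %g\n",
              static_cast<unsigned>(e), ue8m0_decode(e));
  // Prints exponent 122, scale 0.03125 (the power of two at or above 13.5 / 448).
}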
