Commit e569912
feat: Add alignment in MxFP8Quantization (#1445)
## 📌 Description

Most of the code is from trtllm.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 160e4b8 commit e569912

File tree: 7 files changed, +157 −43 lines

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 11 additions & 10 deletions
@@ -74,15 +74,15 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const
 // MXFP8 Quantization

 template <typename T>
-void invokeMxFP8Quantization(int b, int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
-                             FP4QuantizationSFLayout layout, int multiProcessorCount,
-                             cudaStream_t stream) {
+void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, int64_t* output,
+                             int32_t* SFOuput, FP4QuantizationSFLayout layout,
+                             int multiProcessorCount, cudaStream_t stream) {
   // Fixed SF_VEC_SIZE as 32
   static constexpr int SF_VEC_SIZE = 32;

   // Grid, Block size.
   // Each thread converts 8 values.
-  dim3 block(std::min(int(n / CVT_FP4_ELTS_PER_THREAD), 512));
+  dim3 block(std::min(int(padded_n / CVT_FP4_ELTS_PER_THREAD), 512));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = std::max(1u, 2048u / block.x);
   dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

@@ -101,7 +101,7 @@ void invokeMxFP8Quantization(int b, int m, int n, T const* input, int64_t* outpu
   cudaLaunchKernelEx(
       &config,
       quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
-      m, n, input, nullptr, reinterpret_cast<uint32_t*>(output),
+      m, n, padded_n, input, nullptr, reinterpret_cast<uint32_t*>(output),
       reinterpret_cast<uint32_t*>(SFOuput), layout);
 }

@@ -163,7 +163,7 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3);
 #endif

 ////////////////////////////////////////////////////////////////////////////////////////////////////
-// FP4 Quantization
+// FP4/MXFP8 Quantization

 template <typename T, int SF_VEC_SIZE>
 void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, int64_t* output,

@@ -355,9 +355,10 @@ template void invokeBatchedFP4Quantization<half, 16>(
 template void invokeBatchedFP4Quantization<half, 32>(
     int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput,
     bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream);
-template void invokeMxFP8Quantization<half>(int b, int m, int n, half const* input, int64_t* output,
-                                            int32_t* SFOuput, FP4QuantizationSFLayout layout,
-                                            int multiProcessorCount, cudaStream_t stream);
+template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input,
+                                            int64_t* output, int32_t* SFOuput,
+                                            FP4QuantizationSFLayout layout, int multiProcessorCount,
+                                            cudaStream_t stream);
 #ifdef ENABLE_BF16
 template void invokeFP4Quantization<__nv_bfloat16, 16>(int m, int n, __nv_bfloat16 const* input,
                                                        float const* SFScale, int64_t* output,

@@ -379,7 +380,7 @@ template void invokeBatchedFP4Quantization<__nv_bfloat16, 32>(
     int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
     int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
     cudaStream_t stream);
-template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n,
+template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n,
                                                      __nv_bfloat16 const* input, int64_t* output,
                                                      int32_t* SFOuput,
                                                      FP4QuantizationSFLayout layout,
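For reference, the launch shape derived from the new `padded_n` parameter is a few lines of integer arithmetic. The sketch below is not part of the commit; it assumes `CVT_FP4_ELTS_PER_THREAD` is 8 (the comment above says each thread converts 8 values) and uses the same 2048-threads-per-SM occupancy assumption as the launcher.

```python
# Sketch of the grid/block sizing in invokeMxFP8Quantization (assumptions noted above).
def mxfp8_launch_shape(m: int, padded_n: int, multi_processor_count: int):
    elts_per_thread = 8  # assumed value of CVT_FP4_ELTS_PER_THREAD
    block_x = min(padded_n // elts_per_thread, 512)
    num_blocks_per_sm = max(1, 2048 // block_x)  # "assume we can fully utilize the SM"
    grid_x = min(m, multi_processor_count * num_blocks_per_sm)
    return grid_x, block_x

# Example: a [1024, 1568] input padded to 1664 columns on a hypothetical 132-SM GPU.
print(mxfp8_launch_shape(1024, 1664, 132))  # (1024, 208)
```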

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Lines changed: 34 additions & 21 deletions
@@ -924,8 +924,8 @@ __launch_bounds__(512, 4) quantize_with_block_size(
 #else
 quantize_with_block_size(
 #endif
-    int32_t numbatches, int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
-    uint32_t* out, uint32_t* SFout, FP4QuantizationSFLayout layout) {
+    int32_t numbatches, int32_t numRows, int32_t numCols, int32_t numPaddedCols, Type const* in,
+    float const* SFScale, uint32_t* out, uint32_t* SFout, FP4QuantizationSFLayout layout) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)

   // The elements per thread.

@@ -941,46 +941,59 @@ quantize_with_block_size(
   // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)).
   float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

-  int numPaddedRows = numRows;
-  int numPaddedCols = numCols;
-  if (layout == FP4QuantizationSFLayout::SWIZZLED_128x4) {
-    // The number of padded rows considering 128x4 SF layout.
-    numPaddedRows = PadUpFn(numRows, 128);
-    numPaddedCols = PadUpFn(numCols, 4 * SF_VEC_SIZE);
-  } else if (layout == FP4QuantizationSFLayout::SWIZZLED_8x4) {
-    // The number of padded rows considering 8x4 SF layout.
-    numPaddedRows = PadUpFn(numRows, 8);
-    numPaddedCols = PadUpFn(numCols, 4 * SF_VEC_SIZE);
-  }
+  // Is it swizzled layout?
+  bool isSfSwizzledLayout = layout == FP4QuantizationSFLayout::SWIZZLED_128x4 ||
+                            layout == FP4QuantizationSFLayout::SWIZZLED_8x4;
+
+  // The number of padded rows considering 128x4 SF layout.
+  int numPaddedRowsForSf = isSfSwizzledLayout ? PadUpFn(numRows, 128) : numRows;
+  int numColsForSf = isSfSwizzledLayout ? PadUpFn(numPaddedCols, 4 * SF_VEC_SIZE) : numPaddedCols;

-  // The number of threads in the column dimension
+  // The number of threads in the column dimension.
+  // Note that numCols/numPaddedCols/numColsForSf are guaranteed to be multiples of ELTS_PER_THREAD.
   int numColThreads = numCols / ELTS_PER_THREAD;
   int numPaddedColThreads = numPaddedCols / ELTS_PER_THREAD;
+  int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD;

   asm volatile("griddepcontrol.wait;");
   // Input tensor batch/row/col loops.
-  for (int rowIdx = blockIdx.x; rowIdx < numPaddedRows; rowIdx += gridDim.x) {
+  for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
     for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
-      for (int colIdx = threadIdx.x; colIdx < numPaddedColThreads; colIdx += blockDim.x) {
+      for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
         std::optional<int> optionalBatchIdx = batchIdx;
         std::optional<int> optionalNumRows = numRows;

         // The SF output pointer.
         auto sf_out =
             cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF, SF_VEC_SIZE>(
-                optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numCols, SFout, layout);
+                optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols, SFout, layout);
+
+        // The input tensor offset.
+        int64_t inOffset =
+            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
+        int64_t outOffset =
+            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
+
+        // Set the values to 0 of those are padded columns.
+        if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) {
+          // Dispatch the quantization kernel.
+          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+            reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
+          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
+                               quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+            reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
+          }
+        }

         // Set the SF padding to 0.
         if (rowIdx >= numRows || colIdx >= numColThreads) {
+          // Set the SF padding to 0.
           if (sf_out != nullptr) {
            sf_out[0] = 0x00;
           }
         } else {
-          int64_t inOffset =
-              static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
+          // Load the input vector.
           PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-          // Get the output tensor offset as a packed vector.
-          int64_t outOffset = inOffset;

           // Dispatch the quantization kernel.
           if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
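The net effect of the kernel changes above: the loops now run over the scale-factor-padded extents (rows padded to a multiple of 128 for the swizzled layouts, columns to a multiple of 4 * SF_VEC_SIZE), the alignment-padded output columns are written as zeros, and the padded scale-factor entries are zeroed. The extent arithmetic can be sketched in a few lines; this is an illustration, not code from the commit.

```python
SF_VEC_SIZE = 32  # fixed for MXFP8 (see quantization.cu above)

def pad_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

def sf_loop_extents(num_rows: int, num_padded_cols: int, is_sf_swizzled: bool):
    # Mirrors numPaddedRowsForSf / numColsForSf computed by the kernel.
    rows_for_sf = pad_up(num_rows, 128) if is_sf_swizzled else num_rows
    cols_for_sf = pad_up(num_padded_cols, 4 * SF_VEC_SIZE) if is_sf_swizzled else num_padded_cols
    return rows_for_sf, cols_for_sf

# Example: 1000 rows, 1664 alignment-padded columns, swizzled SF layout.
print(sf_loop_extents(1000, 1664, True))  # (1024, 1664)
```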

csrc/nv_internal/tensorrt_llm/kernels/quantization.h

Lines changed: 3 additions & 3 deletions
@@ -94,9 +94,9 @@ void invokeNVFP4BlockScaleInterleaveReverse(int b, int m, int n, uint8_t const*
                                             cudaStream_t stream = 0);

 template <typename T>
-void invokeMxFP8Quantization(int b, int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
-                             FP4QuantizationSFLayout layout, int multiProcessorCount,
-                             cudaStream_t stream = 0);
+void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, int64_t* output,
+                             int32_t* SFOuput, FP4QuantizationSFLayout layout,
+                             int multiProcessorCount, cudaStream_t stream = 0);

 }  // namespace kernels
 }  // namespace tensorrt_llm

csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp

Lines changed: 10 additions & 5 deletions
@@ -18,6 +18,8 @@

 #include <ATen/cuda/EmptyTensor.h>

+#include <cstdint>
+
 #include "cutlass/numeric_types.h"
 #include "pytorch_extension_utils.h"
 #include "tensorrt_llm/thop/thUtils.h"

@@ -27,8 +29,10 @@ namespace torch_ext {
 // input: [M, K], fp32/fp16/bf16/fp8_quantized
 // isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in
 // linear layout. See FP4QuantizationSFLayout enum for more details about the two layouts.
+// alignment: sfVecSize
 // returns
-std::tuple<at::Tensor, at::Tensor> mxfp8_quantize(at::Tensor input, bool isSfSwizzledLayout) {
+std::tuple<at::Tensor, at::Tensor> mxfp8_quantize(at::Tensor input, bool isSfSwizzledLayout,
+                                                  int64_t alignment) {
   CHECK_TH_CUDA(input);
   CHECK_CONTIGUOUS(input);

@@ -43,17 +47,18 @@ std::tuple<at::Tensor, at::Tensor> mxfp8_quantize(at::Tensor input, bool isSfSwi
   auto const k = inputShape[rank - 1];
   int32_t const sfVecSize = 32;
   TORCH_CHECK(k % sfVecSize == 0);
+  auto const padded_k = ((k + alignment - 1) / alignment) * alignment;

   std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
-  outputShape[rank - 1] = k;
+  outputShape[rank - 1] = padded_k;

   at::Tensor valueFP8 =
       at::detail::empty_cuda(outputShape, at::ScalarType::Float8_e4m3fn, input.device(),
                              /* stride */ std::nullopt);

   int64_t SFSize = isSfSwizzledLayout
-                       ? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)
-                       : tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize);
+                       ? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, padded_k / sfVecSize)
+                       : tensorrt_llm::computeFP4LinearLayoutSFSize(m, padded_k / sfVecSize);

   at::Tensor scaleFP8SF = at::detail::empty_cuda({SFSize}, SF_DTYPE, input.device(),
                                                  /* stride */ std::nullopt);  // 1D tensor

@@ -65,7 +70,7 @@ std::tuple<at::Tensor, at::Tensor> mxfp8_quantize(at::Tensor input, bool isSfSwi

 #define LAUNCH_MXFP8_QUANTIZE_KERNEL(T)                                                    \
   tensorrt_llm::kernels::invokeMxFP8Quantization<T>(                                       \
-      1, m, k, reinterpret_cast<T*>(input.data_ptr()),                                     \
+      1, m, k, padded_k, reinterpret_cast<T*>(input.data_ptr()),                           \
       reinterpret_cast<int64_t*>(valueFP8.data_ptr()),                                     \
       reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), layout, mMultiProcessorCount,     \
       at::cuda::getCurrentCUDAStream(input.get_device()));
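On the host side, `padded_k` and the scale-factor buffer size follow from `alignment` and the fixed `sfVecSize` of 32. The sketch below reproduces that arithmetic; the `padded_k` formula is taken from the diff, while the swizzled SF size (rows padded to 128, SF columns to a multiple of 4) is an assumption based on the 128x4 layout handling in quantization.cuh, not a transcription of `computeFP4SwizzledLayoutSFSize`.

```python
def ceil_to(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

def mxfp8_output_shapes(m: int, k: int, alignment: int, is_sf_swizzled: bool):
    sf_vec_size = 32
    assert k % sf_vec_size == 0
    padded_k = ceil_to(k, alignment)   # last dim of the FP8 output tensor
    sf_cols = padded_k // sf_vec_size  # one scale factor per 32 elements
    if is_sf_swizzled:
        sf_size = ceil_to(m, 128) * ceil_to(sf_cols, 4)  # assumed 128x4 swizzled layout
    else:
        sf_size = m * sf_cols                            # linear layout
    return (m, padded_k), sf_size

# Example: a [1024, 1568] input with alignment=128.
print(mxfp8_output_shapes(1024, 1568, 128, True))  # ((1024, 1664), 53248)
```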

csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h

Lines changed: 3 additions & 1 deletion
@@ -62,9 +62,11 @@ inline int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn,
 // input: [M, K], fp16/bf16_quantized
 // isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in
 // linear layout. See FP4QuantizationSFLayout enum for more details about the two layouts.
+// alignment: sfVecSize
 // returns fp8_quantized and block_scale_factors.
 std::tuple<at::Tensor, at::Tensor> mxfp8_quantize(at::Tensor input,
-                                                  bool is_sf_swizzled_layout = true);
+                                                  bool is_sf_swizzled_layout = true,
+                                                  int64_t alignment = 32);

 // x_fp32: [M, K], fp32_quantized (on the host)
 // isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in

flashinfer/fp8_quantization.py

Lines changed: 6 additions & 3 deletions
@@ -49,13 +49,14 @@ def get_mxfp8_quantization_sm100_module():
     def mxfp8_quantize_sm100(
         input: torch.Tensor,
         is_sf_swizzled_layout: bool = True,
+        alignment: int = 32,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Quantize input tensor to MxFP8 format.

         Args:
             input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
             is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
-
+            alignment (int, optional): sfVecSize. Defaults to 32. Note that alignment is not used in the host kernel.
         Returns:
             Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                 - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3

@@ -70,6 +71,7 @@ def mxfp8_quantize_sm100(
         return module.mxfp8_quantize(
             input,
             is_sf_swizzled_layout,
+            alignment,
         )

     @register_fake_op("flashinfer::mxfp8_quantize_sm100")

@@ -126,6 +128,7 @@ def _fake_mxfp8_dequantize_host_sm100(
 def mxfp8_quantize(
     input: torch.Tensor,
     is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Quantize input tensor to MxFP8 format.

@@ -135,7 +138,7 @@ def mxfp8_quantize(
     Args:
         input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
-
+        alignment (int, optional): sfVecSize. Defaults to 32.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
             - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3

@@ -147,8 +150,8 @@ def mxfp8_quantize(
     x_q, sf = get_mxfp8_quantization_sm100_module().mxfp8_quantize_sm100(
         input,
         is_sf_swizzled_layout,
+        alignment,
     )
-    sf = sf.reshape((-1, input.shape[-1] // sf_vec_size))
     return x_q, sf
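From the Python side, the new parameter is exposed as a third argument to `mxfp8_quantize`. A usage sketch (assuming a CUDA build of FlashInfer and that `mxfp8_quantize` is imported from `flashinfer.fp8_quantization`, the file shown above):

```python
import torch
from flashinfer.fp8_quantization import mxfp8_quantize

x = torch.randn(16, 1568, dtype=torch.bfloat16, device="cuda")
# Pad the quantized output's last dimension up to a multiple of `alignment`;
# padded columns and their scale factors are zero-filled by the device kernel.
x_fp8, sf = mxfp8_quantize(x, is_sf_swizzled_layout=True, alignment=128)
print(x_fp8.shape)  # torch.Size([16, 1664]), dtype torch.float8_e4m3fn
```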

tests/test_fp8_quantize.py

Lines changed: 90 additions & 0 deletions
@@ -50,5 +50,95 @@ def check_accuracy(a, b, atol, rtol, percent):
     check_accuracy(a_pt, a, 8, 0, 0.999)


+def mxfp8_quantize_check_accuracy(a, b, atol, rtol, percent):
+    if torch.any(torch.isnan(a)):
+        raise Exception("NaN in a")
+    if torch.any(torch.isnan(b)):
+        raise Exception("NaN in b")
+    assert a.shape == b.shape
+    left = torch.abs(a - b)
+    right = atol + rtol * torch.abs(b)
+    count = torch.sum(left > right)
+    mismatch_percent = count / a.numel()
+    if mismatch_percent > 1 - percent:
+        raise Exception(
+            "Mismatch percentage is %f for rtol %f" % (mismatch_percent, rtol)
+        )
+
+
+@pytest.mark.parametrize("m", [1, 2, 16, 1024])
+@pytest.mark.parametrize("k", [512, 1024])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
+def test_mxfp8_quantize_torch_host(m, k, dtype, is_sf_swizzled_layout):
+    torch.random.manual_seed(0)
+    a = (torch.randn([m, k], dtype=torch.float) * 16).cpu().contiguous()
+
+    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)
+
+    a_pt = mxfp8_dequantize_host(
+        a_fp8.view(torch.uint8), a_sf.view(torch.uint8), is_sf_swizzled_layout
+    )
+
+    torch.cuda.synchronize()
+
+    mxfp8_quantize_check_accuracy(a_pt, a, 8, 0, 0.999)
+
+
+@pytest.mark.parametrize("m", [1, 2, 16, 1024])
+@pytest.mark.parametrize("k", [512, 1024])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
+def test_mxfp8_quantize_torch_device(m, k, dtype, is_sf_swizzled_layout):
+    torch.random.manual_seed(0)
+    a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous()
+
+    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, 32)
+    a_pt = mxfp8_dequantize_host(
+        a_fp8.cpu().view(torch.uint8),
+        a_sf.cpu().view(torch.uint8),
+        is_sf_swizzled_layout,
+    )
+
+    torch.cuda.synchronize()
+    mxfp8_quantize_check_accuracy(
+        a_pt.cpu().to(torch.float32), a.cpu().to(torch.float32), 8, 0, 0.999
+    )
+
+
+@pytest.mark.parametrize("m", [1, 2, 16, 1024])
+@pytest.mark.parametrize("k", [1568])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
+@pytest.mark.parametrize("alignment", [64, 128])
+def test_mxfp8_quantize_alignment_torch_device(
+    m, k, dtype, is_sf_swizzled_layout, alignment
+):
+    torch.random.manual_seed(0)
+    a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous()
+    padded_k = ((k + alignment - 1) // alignment) * alignment
+
+    # Quantize it on device.
+    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, alignment)
+    assert a_fp8.shape[1] == padded_k
+
+    # Dequantize it on host.
+    a_pt = mxfp8_dequantize_host(
+        a_fp8.cpu().view(torch.uint8),
+        a_sf.cpu().view(torch.uint8),
+        is_sf_swizzled_layout,
+    )
+
+    # Check if the bits of paddings are zero.
+    paddings = a_fp8.view(torch.int8)[:, k:]
+    assert torch.all(paddings == 0), "Paddings should be zero"
+
+    torch.cuda.synchronize()
+
+    mxfp8_quantize_check_accuracy(
+        a_pt[:, :k].cpu().to(torch.float32), a.cpu().to(torch.float32), 8, 0, 0.999
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
